diff options
Diffstat (limited to 'sqlglot/dataframe')
-rw-r--r-- | sqlglot/dataframe/sql/column.py | 8 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/dataframe.py | 2 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/functions.py | 8 |
3 files changed, 10 insertions, 8 deletions
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py index 609b2a4..f45d467 100644 --- a/sqlglot/dataframe/sql/column.py +++ b/sqlglot/dataframe/sql/column.py @@ -67,10 +67,10 @@ class Column: return self.binary_op(exp.Mul, other) def __truediv__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.Div, other) + return self.binary_op(exp.FloatDiv, other) def __div__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.Div, other) + return self.binary_op(exp.FloatDiv, other) def __neg__(self) -> Column: return self.unary_op(exp.Neg) @@ -85,10 +85,10 @@ class Column: return self.inverse_binary_op(exp.Mul, other) def __rdiv__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.Div, other) + return self.inverse_binary_op(exp.FloatDiv, other) def __rtruediv__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.Div, other) + return self.inverse_binary_op(exp.FloatDiv, other) def __rmod__(self, other: ColumnOrLiteral) -> Column: return self.inverse_binary_op(exp.Mod, other) diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index 93ca45a..32ee927 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -260,7 +260,7 @@ class DataFrame: @classmethod def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]: expression = item.expression if isinstance(item, DataFrame) else item - return [Column(x) for x in expression.find(exp.Select).expressions] + return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions] @classmethod def _create_hash_from_expression(cls, expression: exp.Select): diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index 8f24746..3c98f42 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -954,10 +954,12 @@ def array_join( col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None ) -> Column: if null_replacement is not None: - return Column.invoke_anonymous_function( - col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement) + return Column.invoke_expression_over_column( + col, expression.ArrayJoin, expression=lit(delimiter), null=lit(null_replacement) ) - return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter)) + return Column.invoke_expression_over_column( + col, expression.ArrayJoin, expression=lit(delimiter) + ) def concat(*cols: ColumnOrName) -> Column: |