summaryrefslogtreecommitdiffstats
path: root/sqlglot/dataframe/sql
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dataframe/sql')
-rw-r--r--sqlglot/dataframe/sql/column.py8
-rw-r--r--sqlglot/dataframe/sql/dataframe.py2
-rw-r--r--sqlglot/dataframe/sql/functions.py8
3 files changed, 10 insertions, 8 deletions
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
index 609b2a4..f45d467 100644
--- a/sqlglot/dataframe/sql/column.py
+++ b/sqlglot/dataframe/sql/column.py
@@ -67,10 +67,10 @@ class Column:
return self.binary_op(exp.Mul, other)
def __truediv__(self, other: ColumnOrLiteral) -> Column:
- return self.binary_op(exp.Div, other)
+ return self.binary_op(exp.FloatDiv, other)
def __div__(self, other: ColumnOrLiteral) -> Column:
- return self.binary_op(exp.Div, other)
+ return self.binary_op(exp.FloatDiv, other)
def __neg__(self) -> Column:
return self.unary_op(exp.Neg)
@@ -85,10 +85,10 @@ class Column:
return self.inverse_binary_op(exp.Mul, other)
def __rdiv__(self, other: ColumnOrLiteral) -> Column:
- return self.inverse_binary_op(exp.Div, other)
+ return self.inverse_binary_op(exp.FloatDiv, other)
def __rtruediv__(self, other: ColumnOrLiteral) -> Column:
- return self.inverse_binary_op(exp.Div, other)
+ return self.inverse_binary_op(exp.FloatDiv, other)
def __rmod__(self, other: ColumnOrLiteral) -> Column:
return self.inverse_binary_op(exp.Mod, other)
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 93ca45a..32ee927 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -260,7 +260,7 @@ class DataFrame:
@classmethod
def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
expression = item.expression if isinstance(item, DataFrame) else item
- return [Column(x) for x in expression.find(exp.Select).expressions]
+ return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions]
@classmethod
def _create_hash_from_expression(cls, expression: exp.Select):
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index 8f24746..3c98f42 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -954,10 +954,12 @@ def array_join(
col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None
) -> Column:
if null_replacement is not None:
- return Column.invoke_anonymous_function(
- col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement)
+ return Column.invoke_expression_over_column(
+ col, expression.ArrayJoin, expression=lit(delimiter), null=lit(null_replacement)
)
- return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter))
+ return Column.invoke_expression_over_column(
+ col, expression.ArrayJoin, expression=lit(delimiter)
+ )
def concat(*cols: ColumnOrName) -> Column: