diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/dataframe/sql/column.py | 5 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/dataframe.py | 8 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/functions.py | 18 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/session.py | 8 |
4 files changed, 24 insertions, 15 deletions
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py index f9e1c5b..22075e9 100644 --- a/sqlglot/dataframe/sql/column.py +++ b/sqlglot/dataframe/sql/column.py @@ -260,7 +260,10 @@ class Column: """ if isinstance(dataType, DataType): dataType = dataType.simpleString() - new_expression = exp.Cast(this=self.column_expression, to=dataType) + new_expression = exp.Cast( + this=self.column_expression, + to=sqlglot.parse_one(dataType, into=exp.DataType, read="spark"), # type: ignore + ) return Column(new_expression) def startswith(self, value: t.Union[str, Column]) -> Column: diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index 40cd6c9..548c322 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -314,7 +314,13 @@ class DataFrame: replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier( # type: ignore cache_table_name ) - sqlglot.schema.add_table(cache_table_name, select_expression.named_selects) + sqlglot.schema.add_table( + cache_table_name, + { + expression.alias_or_name: expression.type.name + for expression in select_expression.expressions + }, + ) cache_storage_level = select_expression.args["cache_storage_level"] options = [ exp.Literal.string("storageLevel"), diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index dbfb06f..1ee361a 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -757,11 +757,15 @@ def concat_ws(sep: str, *cols: ColumnOrName) -> Column: def decode(col: ColumnOrName, charset: str) -> Column: - return Column.invoke_anonymous_function(col, "DECODE", lit(charset)) + return Column.invoke_expression_over_column( + col, glotexp.Decode, charset=glotexp.Literal.string(charset) + ) def encode(col: ColumnOrName, charset: str) -> Column: - return Column.invoke_anonymous_function(col, "ENCODE", lit(charset)) + return Column.invoke_expression_over_column( + col, glotexp.Encode, charset=glotexp.Literal.string(charset) + ) def format_number(col: ColumnOrName, d: int) -> Column: @@ -867,11 +871,11 @@ def bin(col: ColumnOrName) -> Column: def hex(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "HEX") + return Column.invoke_expression_over_column(col, glotexp.Hex) def unhex(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "UNHEX") + return Column.invoke_expression_over_column(col, glotexp.Unhex) def length(col: ColumnOrName) -> Column: @@ -939,11 +943,7 @@ def array_join( def concat(*cols: ColumnOrName) -> Column: - if len(cols) == 1: - return Column.invoke_anonymous_function(cols[0], "CONCAT") - return Column.invoke_anonymous_function( - cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]] - ) + return Column.invoke_expression_over_column(None, glotexp.Concat, expressions=cols) def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column: diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py index 8cb16ef..c4a22c6 100644 --- a/sqlglot/dataframe/sql/session.py +++ b/sqlglot/dataframe/sql/session.py @@ -88,14 +88,14 @@ class SparkSession: "expressions": sel_columns, "from": exp.From( expressions=[ - exp.Subquery( - this=exp.Values(expressions=data_expressions), + exp.Values( + expressions=data_expressions, alias=exp.TableAlias( this=exp.to_identifier(self._auto_incrementing_name), columns=[exp.to_identifier(col_name) for col_name in column_mapping], ), - ) - ] + ), + ], ), } |