summary refs log tree commit diff stats
path: root/sqlglot/dataframe
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dataframe')
-rw-r--r-- sqlglot/dataframe/sql/column.py 5
-rw-r--r-- sqlglot/dataframe/sql/dataframe.py 8
-rw-r--r-- sqlglot/dataframe/sql/functions.py 18
-rw-r--r-- sqlglot/dataframe/sql/session.py 8
4 files changed, 24 insertions, 15 deletions
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
index f9e1c5b..22075e9 100644
--- a/sqlglot/dataframe/sql/column.py
+++ b/sqlglot/dataframe/sql/column.py
@@ -260,7 +260,10 @@ class Column:
"""
if isinstance(dataType, DataType):
dataType = dataType.simpleString()
- new_expression = exp.Cast(this=self.column_expression, to=dataType)
+ new_expression = exp.Cast(
+ this=self.column_expression,
+ to=sqlglot.parse_one(dataType, into=exp.DataType, read="spark"), # type: ignore
+ )
return Column(new_expression)
def startswith(self, value: t.Union[str, Column]) -> Column:
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 40cd6c9..548c322 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -314,7 +314,13 @@ class DataFrame:
replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier( # type: ignore
cache_table_name
)
- sqlglot.schema.add_table(cache_table_name, select_expression.named_selects)
+ sqlglot.schema.add_table(
+ cache_table_name,
+ {
+ expression.alias_or_name: expression.type.name
+ for expression in select_expression.expressions
+ },
+ )
cache_storage_level = select_expression.args["cache_storage_level"]
options = [
exp.Literal.string("storageLevel"),
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index dbfb06f..1ee361a 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -757,11 +757,15 @@ def concat_ws(sep: str, *cols: ColumnOrName) -> Column:
def decode(col: ColumnOrName, charset: str) -> Column:
- return Column.invoke_anonymous_function(col, "DECODE", lit(charset))
+ return Column.invoke_expression_over_column(
+ col, glotexp.Decode, charset=glotexp.Literal.string(charset)
+ )
def encode(col: ColumnOrName, charset: str) -> Column:
- return Column.invoke_anonymous_function(col, "ENCODE", lit(charset))
+ return Column.invoke_expression_over_column(
+ col, glotexp.Encode, charset=glotexp.Literal.string(charset)
+ )
def format_number(col: ColumnOrName, d: int) -> Column:
@@ -867,11 +871,11 @@ def bin(col: ColumnOrName) -> Column:
def hex(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "HEX")
+ return Column.invoke_expression_over_column(col, glotexp.Hex)
def unhex(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "UNHEX")
+ return Column.invoke_expression_over_column(col, glotexp.Unhex)
def length(col: ColumnOrName) -> Column:
@@ -939,11 +943,7 @@ def array_join(
def concat(*cols: ColumnOrName) -> Column:
- if len(cols) == 1:
- return Column.invoke_anonymous_function(cols[0], "CONCAT")
- return Column.invoke_anonymous_function(
- cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]]
- )
+ return Column.invoke_expression_over_column(None, glotexp.Concat, expressions=cols)
def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py
index 8cb16ef..c4a22c6 100644
--- a/sqlglot/dataframe/sql/session.py
+++ b/sqlglot/dataframe/sql/session.py
@@ -88,14 +88,14 @@ class SparkSession:
"expressions": sel_columns,
"from": exp.From(
expressions=[
- exp.Subquery(
- this=exp.Values(expressions=data_expressions),
+ exp.Values(
+ expressions=data_expressions,
alias=exp.TableAlias(
this=exp.to_identifier(self._auto_incrementing_name),
columns=[exp.to_identifier(col_name) for col_name in column_mapping],
),
- )
- ]
+ ),
+ ],
),
}