summaryrefslogtreecommitdiffstats
path: root/sqlglot/dataframe/sql/readwriter.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dataframe/sql/readwriter.py')
-rw-r--r--sqlglot/dataframe/sql/readwriter.py14
1 files changed, 6 insertions, 8 deletions
diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py
index cc2f181..9d87d4a 100644
--- a/sqlglot/dataframe/sql/readwriter.py
+++ b/sqlglot/dataframe/sql/readwriter.py
@@ -4,7 +4,8 @@ import typing as t
import sqlglot
from sqlglot import expressions as exp
-from sqlglot.helper import object_to_dict, should_identify
+from sqlglot.dialects import Spark
+from sqlglot.helper import object_to_dict
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.dataframe import DataFrame
@@ -18,17 +19,14 @@ class DataFrameReader:
def table(self, tableName: str) -> DataFrame:
from sqlglot.dataframe.sql.dataframe import DataFrame
- sqlglot.schema.add_table(tableName)
+ sqlglot.schema.add_table(tableName, dialect="spark")
return DataFrame(
self.spark,
exp.Select()
- .from_(tableName)
+ .from_(exp.to_table(tableName, dialect="spark").transform(Spark.normalize_identifier))
.select(
- *(
- column if should_identify(column, "safe") else f'"{column}"'
- for column in sqlglot.schema.column_names(tableName)
- )
+ *(column for column in sqlglot.schema.column_names(tableName, dialect="spark"))
),
)
@@ -73,7 +71,7 @@ class DataFrameWriter:
)
df = self._df.copy(output_expression_container=output_expression_container)
if self._by_name:
- columns = sqlglot.schema.column_names(tableName, only_visible=True)
+ columns = sqlglot.schema.column_names(tableName, only_visible=True, dialect="spark")
df = df._convert_leaf_to_cte().select(*columns)
return self.copy(_df=df)