diff options
Diffstat (limited to 'sqlglot/dataframe/sql/readwriter.py')
-rw-r--r-- | sqlglot/dataframe/sql/readwriter.py | 23 |
1 files changed, 18 insertions, 5 deletions
diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py index 9d87d4a..0804486 100644 --- a/sqlglot/dataframe/sql/readwriter.py +++ b/sqlglot/dataframe/sql/readwriter.py @@ -4,7 +4,6 @@ import typing as t import sqlglot from sqlglot import expressions as exp -from sqlglot.dialects import Spark from sqlglot.helper import object_to_dict if t.TYPE_CHECKING: @@ -18,15 +17,25 @@ class DataFrameReader: def table(self, tableName: str) -> DataFrame: from sqlglot.dataframe.sql.dataframe import DataFrame + from sqlglot.dataframe.sql.session import SparkSession - sqlglot.schema.add_table(tableName, dialect="spark") + sqlglot.schema.add_table(tableName, dialect=SparkSession().dialect) return DataFrame( self.spark, exp.Select() - .from_(exp.to_table(tableName, dialect="spark").transform(Spark.normalize_identifier)) + .from_( + exp.to_table(tableName, dialect=SparkSession().dialect).transform( + SparkSession().dialect.normalize_identifier + ) + ) .select( - *(column for column in sqlglot.schema.column_names(tableName, dialect="spark")) + *( + column + for column in sqlglot.schema.column_names( + tableName, dialect=SparkSession().dialect + ) + ) ), ) @@ -63,6 +72,8 @@ class DataFrameWriter: return self.copy(by_name=True) def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter: + from sqlglot.dataframe.sql.session import SparkSession + output_expression_container = exp.Insert( **{ "this": exp.to_table(tableName), @@ -71,7 +82,9 @@ class DataFrameWriter: ) df = self._df.copy(output_expression_container=output_expression_container) if self._by_name: - columns = sqlglot.schema.column_names(tableName, only_visible=True, dialect="spark") + columns = sqlglot.schema.column_names( + tableName, only_visible=True, dialect=SparkSession().dialect + ) df = df._convert_leaf_to_cte().select(*columns) return self.copy(_df=df) |