diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-09-07 11:39:48 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-09-07 11:39:48 +0000 |
commit | f73e9af131151f1e058446361c35b05c4c90bf10 (patch) | |
tree | ed425b89f12d3f5e4709290bdc03d876f365bc97 /sqlglot/dataframe/sql/readwriter.py | |
parent | Releasing debian version 17.12.0-1. (diff) | |
download | sqlglot-f73e9af131151f1e058446361c35b05c4c90bf10.tar.xz sqlglot-f73e9af131151f1e058446361c35b05c4c90bf10.zip |
Merging upstream version 18.2.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dataframe/sql/readwriter.py')
-rw-r--r-- | sqlglot/dataframe/sql/readwriter.py | 23 |
1 files changed, 18 insertions, 5 deletions
diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py index 9d87d4a..0804486 100644 --- a/sqlglot/dataframe/sql/readwriter.py +++ b/sqlglot/dataframe/sql/readwriter.py @@ -4,7 +4,6 @@ import typing as t import sqlglot from sqlglot import expressions as exp -from sqlglot.dialects import Spark from sqlglot.helper import object_to_dict if t.TYPE_CHECKING: @@ -18,15 +17,25 @@ class DataFrameReader: def table(self, tableName: str) -> DataFrame: from sqlglot.dataframe.sql.dataframe import DataFrame + from sqlglot.dataframe.sql.session import SparkSession - sqlglot.schema.add_table(tableName, dialect="spark") + sqlglot.schema.add_table(tableName, dialect=SparkSession().dialect) return DataFrame( self.spark, exp.Select() - .from_(exp.to_table(tableName, dialect="spark").transform(Spark.normalize_identifier)) + .from_( + exp.to_table(tableName, dialect=SparkSession().dialect).transform( + SparkSession().dialect.normalize_identifier + ) + ) .select( - *(column for column in sqlglot.schema.column_names(tableName, dialect="spark")) + *( + column + for column in sqlglot.schema.column_names( + tableName, dialect=SparkSession().dialect + ) + ) ), ) @@ -63,6 +72,8 @@ class DataFrameWriter: return self.copy(by_name=True) def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter: + from sqlglot.dataframe.sql.session import SparkSession + output_expression_container = exp.Insert( **{ "this": exp.to_table(tableName), @@ -71,7 +82,9 @@ class DataFrameWriter: ) df = self._df.copy(output_expression_container=output_expression_container) if self._by_name: - columns = sqlglot.schema.column_names(tableName, only_visible=True, dialect="spark") + columns = sqlglot.schema.column_names( + tableName, only_visible=True, dialect=SparkSession().dialect + ) df = df._convert_leaf_to_cte().select(*columns) return self.copy(_df=df) |