path: root/sqlglot/dataframe/sql/readwriter.py
author     Daniel Baumann <daniel.baumann@progress-linux.org>  2023-09-07 11:39:48 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2023-09-07 11:39:48 +0000
commit     f73e9af131151f1e058446361c35b05c4c90bf10 (patch)
tree       ed425b89f12d3f5e4709290bdc03d876f365bc97 /sqlglot/dataframe/sql/readwriter.py
parent     Releasing debian version 17.12.0-1. (diff)
download   sqlglot-f73e9af131151f1e058446361c35b05c4c90bf10.tar.xz
           sqlglot-f73e9af131151f1e058446361c35b05c4c90bf10.zip
Merging upstream version 18.2.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dataframe/sql/readwriter.py')
-rw-r--r--  sqlglot/dataframe/sql/readwriter.py  23
1 file changed, 18 insertions(+), 5 deletions(-)
diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py
index 9d87d4a..0804486 100644
--- a/sqlglot/dataframe/sql/readwriter.py
+++ b/sqlglot/dataframe/sql/readwriter.py
@@ -4,7 +4,6 @@ import typing as t
 
 import sqlglot
 from sqlglot import expressions as exp
-from sqlglot.dialects import Spark
 from sqlglot.helper import object_to_dict
 
 if t.TYPE_CHECKING:
@@ -18,15 +17,25 @@ class DataFrameReader:
 
     def table(self, tableName: str) -> DataFrame:
         from sqlglot.dataframe.sql.dataframe import DataFrame
+        from sqlglot.dataframe.sql.session import SparkSession
 
-        sqlglot.schema.add_table(tableName, dialect="spark")
+        sqlglot.schema.add_table(tableName, dialect=SparkSession().dialect)
 
         return DataFrame(
             self.spark,
             exp.Select()
-            .from_(exp.to_table(tableName, dialect="spark").transform(Spark.normalize_identifier))
+            .from_(
+                exp.to_table(tableName, dialect=SparkSession().dialect).transform(
+                    SparkSession().dialect.normalize_identifier
+                )
+            )
             .select(
-                *(column for column in sqlglot.schema.column_names(tableName, dialect="spark"))
+                *(
+                    column
+                    for column in sqlglot.schema.column_names(
+                        tableName, dialect=SparkSession().dialect
+                    )
+                )
             ),
         )
 
@@ -63,6 +72,8 @@ class DataFrameWriter:
         return self.copy(by_name=True)
 
     def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
+        from sqlglot.dataframe.sql.session import SparkSession
+
         output_expression_container = exp.Insert(
             **{
                 "this": exp.to_table(tableName),
@@ -71,7 +82,9 @@
         )
         df = self._df.copy(output_expression_container=output_expression_container)
         if self._by_name:
-            columns = sqlglot.schema.column_names(tableName, only_visible=True, dialect="spark")
+            columns = sqlglot.schema.column_names(
+                tableName, only_visible=True, dialect=SparkSession().dialect
+            )
             df = df._convert_leaf_to_cte().select(*columns)
         return self.copy(_df=df)
 
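
For reference, a minimal usage sketch of the changed DataFrameReader.table() path (not part of the diff): the table name and columns are illustrative, and it assumes the default session dialect. After this change, schema registration, identifier normalization, and column lookup all go through SparkSession().dialect instead of the previously hard-coded "spark" dialect; DataFrameWriter.insertInto() resolves columns the same way when byName() is used.

    # Illustrative sketch only; "employee" and its columns are hypothetical.
    import sqlglot
    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession()

    # Register the table so column_names() can resolve its columns; the schema
    # lookup now uses the session's dialect rather than a hard-coded "spark".
    sqlglot.schema.add_table(
        "employee",
        {"employee_id": "INT", "fname": "STRING"},
        dialect=spark.dialect,
    )

    # table() normalizes the identifier with the session dialect and selects
    # the registered columns.
    df = spark.read.table("employee")
    print(df.sql())  # list containing the generated SELECT statement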