summaryrefslogtreecommitdiffstats
path: root/sqlglot/dataframe/sql/readwriter.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dataframe/sql/readwriter.py')
-rw-r--r--sqlglot/dataframe/sql/readwriter.py16
1 files changed, 13 insertions, 3 deletions
diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py
index 4830035..febc664 100644
--- a/sqlglot/dataframe/sql/readwriter.py
+++ b/sqlglot/dataframe/sql/readwriter.py
@@ -19,12 +19,19 @@ class DataFrameReader:
from sqlglot.dataframe.sql.dataframe import DataFrame
sqlglot.schema.add_table(tableName)
- return DataFrame(self.spark, exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)))
+ return DataFrame(
+ self.spark,
+ exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)),
+ )
class DataFrameWriter:
def __init__(
- self, df: DataFrame, spark: t.Optional[SparkSession] = None, mode: t.Optional[str] = None, by_name: bool = False
+ self,
+ df: DataFrame,
+ spark: t.Optional[SparkSession] = None,
+ mode: t.Optional[str] = None,
+ by_name: bool = False,
):
self._df = df
self._spark = spark or df.spark
@@ -33,7 +40,10 @@ class DataFrameWriter:
def copy(self, **kwargs) -> DataFrameWriter:
return DataFrameWriter(
- **{k[1:] if k.startswith("_") else k: v for k, v in object_to_dict(self, **kwargs).items()}
+ **{
+ k[1:] if k.startswith("_") else k: v
+ for k, v in object_to_dict(self, **kwargs).items()
+ }
)
def sql(self, **kwargs) -> t.List[str]: