Diffstat (limited to 'sqlglot/dataframe')
 sqlglot/dataframe/README.md         | 18
 sqlglot/dataframe/sql/column.py     |  5
 sqlglot/dataframe/sql/dataframe.py  |  1
 sqlglot/dataframe/sql/normalize.py  |  2
 sqlglot/dataframe/sql/readwriter.py | 14
 5 files changed, 25 insertions(+), 15 deletions(-)
diff --git a/sqlglot/dataframe/README.md b/sqlglot/dataframe/README.md
index 02179f4..86fdc4b 100644
--- a/sqlglot/dataframe/README.md
+++ b/sqlglot/dataframe/README.md
@@ -9,7 +9,7 @@ Currently many of the common operations are covered and more functionality will
## Instructions
* [Install SQLGlot](https://github.com/tobymao/sqlglot/blob/main/README.md#install) and that is all that is required to just generate SQL. [The examples](#examples) show generating SQL and then executing that SQL on a specific engine and that will require that engine's client library.
* Find/replace all `from pyspark.sql` with `from sqlglot.dataframe`.
-* Prior to any `spark.read.table` or `spark.table` run `sqlglot.schema.add_table('<table_name>', <column_structure>)`.
+* Prior to any `spark.read.table` or `spark.table` run `sqlglot.schema.add_table('<table_name>', <column_structure>, dialect="spark")`.
* The column structure can be defined the following ways:
* Dictionary where the keys are column names and values are string of the Spark SQL type name.
* Ex: `{'cola': 'string', 'colb': 'int'}`
@@ -33,12 +33,16 @@ import sqlglot
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F
-sqlglot.schema.add_table('employee', {
- 'employee_id': 'INT',
- 'fname': 'STRING',
- 'lname': 'STRING',
- 'age': 'INT',
-}) # Register the table structure prior to reading from the table
+sqlglot.schema.add_table(
+ 'employee',
+ {
+ 'employee_id': 'INT',
+ 'fname': 'STRING',
+ 'lname': 'STRING',
+ 'age': 'INT',
+ },
+ dialect="spark",
+) # Register the table structure prior to reading from the table
spark = SparkSession()
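
For reference, a minimal end-to-end sketch of the pattern the README now documents; the table and column names here are illustrative, not part of the patch:

    import sqlglot
    from sqlglot.dataframe.sql.session import SparkSession
    from sqlglot.dataframe.sql import functions as F

    # Register the schema first; the new dialect argument tells sqlglot how to
    # normalize the identifiers (Spark lowercases unquoted names).
    sqlglot.schema.add_table('employee', {'fname': 'STRING', 'age': 'INT'}, dialect="spark")

    spark = SparkSession()
    df = spark.table('employee').select(F.col('fname'), F.col('age'))
    print(df.sql(pretty=True))  # generated Spark SQL, returned as a list of statements
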
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
index a8b89d1..f4cfeba 100644
--- a/sqlglot/dataframe/sql/column.py
+++ b/sqlglot/dataframe/sql/column.py
@@ -5,6 +5,7 @@ import typing as t
import sqlglot
from sqlglot import expressions as exp
from sqlglot.dataframe.sql.types import DataType
+from sqlglot.dialects import Spark
from sqlglot.helper import flatten, is_iterable
if t.TYPE_CHECKING:
@@ -22,6 +23,10 @@ class Column:
expression = sqlglot.maybe_parse(expression, dialect="spark")
if expression is None:
raise ValueError(f"Could not parse {expression}")
+
+ if isinstance(expression, exp.Column):
+ expression.transform(Spark.normalize_identifier, copy=False)
+
self.expression: exp.Expression = expression
def __repr__(self):
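
A quick sketch of what this new step does, using the same transform the patch applies; the column name is illustrative, and the lowercased output is the expected Spark normalization (unquoted identifiers are case-insensitive):

    import sqlglot
    from sqlglot.dialects import Spark

    expr = sqlglot.maybe_parse("Employee_ID", dialect="spark")  # parses into an exp.Column
    expr.transform(Spark.normalize_identifier, copy=False)      # lowercases the identifier in place
    print(expr.sql(dialect="spark"))                            # employee_id
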
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 3fc9232..64cceea 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -316,6 +316,7 @@ class DataFrame:
expression.alias_or_name: expression.type.sql("spark")
for expression in select_expression.expressions
},
+ dialect="spark",
)
cache_storage_level = select_expression.args["cache_storage_level"]
options = [
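
This is the schema registration behind DataFrame.cache()/persist(): the cached relation's columns are re-registered so later reads resolve, and they now go in under the Spark dialect as well. A sketch of how it is reached, continuing from the README example above:

    df = spark.table('employee').select(F.col('fname'), F.col('age'))
    df = df.cache()             # triggers the schema registration shown in the hunk above
    print(df.sql(pretty=True))  # expected to include a CACHE statement before the SELECT
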
diff --git a/sqlglot/dataframe/sql/normalize.py b/sqlglot/dataframe/sql/normalize.py
index 75feba7..4eec782 100644
--- a/sqlglot/dataframe/sql/normalize.py
+++ b/sqlglot/dataframe/sql/normalize.py
@@ -5,6 +5,7 @@ import typing as t
from sqlglot import expressions as exp
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
+from sqlglot.dialects import Spark
from sqlglot.helper import ensure_list
NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])
@@ -19,6 +20,7 @@ def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[
for expression in expressions:
identifiers = expression.find_all(exp.Identifier)
for identifier in identifiers:
+ Spark.normalize_identifier(identifier)
replace_alias_name_with_cte_name(spark, expression_context, identifier)
replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)
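
In effect, a mixed-case reference written on the DataFrame side is lowercased before the CTE-name lookups run, so it lines up with the normalized schema. A small sketch, again with illustrative names:

    # 'Age' is normalized to 'age' before replace_alias_name_with_cte_name runs,
    # so it matches a schema registered as {'age': 'INT'}.
    df = spark.table('employee').where(F.col('Age') > 18)
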
diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py
index cc2f181..9d87d4a 100644
--- a/sqlglot/dataframe/sql/readwriter.py
+++ b/sqlglot/dataframe/sql/readwriter.py
@@ -4,7 +4,8 @@ import typing as t
import sqlglot
from sqlglot import expressions as exp
-from sqlglot.helper import object_to_dict, should_identify
+from sqlglot.dialects import Spark
+from sqlglot.helper import object_to_dict
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.dataframe import DataFrame
@@ -18,17 +19,14 @@ class DataFrameReader:
def table(self, tableName: str) -> DataFrame:
from sqlglot.dataframe.sql.dataframe import DataFrame
- sqlglot.schema.add_table(tableName)
+ sqlglot.schema.add_table(tableName, dialect="spark")
return DataFrame(
self.spark,
exp.Select()
- .from_(tableName)
+ .from_(exp.to_table(tableName, dialect="spark").transform(Spark.normalize_identifier))
.select(
- *(
- column if should_identify(column, "safe") else f'"{column}"'
- for column in sqlglot.schema.column_names(tableName)
- )
+ *(column for column in sqlglot.schema.column_names(tableName, dialect="spark"))
),
)
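
With those changes, a sketch of the reader path; the mixed-case table name is illustrative and resolves because both add_table and table() now normalize under the Spark dialect:

    import sqlglot
    from sqlglot.dataframe.sql.session import SparkSession

    sqlglot.schema.add_table('Employee', {'fname': 'STRING'}, dialect="spark")  # stored as 'employee'
    spark = SparkSession()
    df = spark.read.table('Employee')  # normalized via exp.to_table(..., dialect="spark")
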
@@ -73,7 +71,7 @@ class DataFrameWriter:
)
df = self._df.copy(output_expression_container=output_expression_container)
if self._by_name:
- columns = sqlglot.schema.column_names(tableName, only_visible=True)
+ columns = sqlglot.schema.column_names(tableName, only_visible=True, dialect="spark")
df = df._convert_leaf_to_cte().select(*columns)
return self.copy(_df=df)
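
And the writer path, continuing the sketch above; byName matches columns by their (now dialect-normalized) names rather than by position:

    sqlglot.schema.add_table('employee_backup', {'fname': 'STRING'}, dialect="spark")
    df.write.byName.insertInto('employee_backup')  # columns resolved by normalized name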