Diffstat (limited to 'sqlglot/dataframe')
-rw-r--r--  sqlglot/dataframe/README.md          | 18
-rw-r--r--  sqlglot/dataframe/sql/column.py      |  5
-rw-r--r--  sqlglot/dataframe/sql/dataframe.py   |  1
-rw-r--r--  sqlglot/dataframe/sql/normalize.py   |  2
-rw-r--r--  sqlglot/dataframe/sql/readwriter.py  | 14
5 files changed, 25 insertions, 15 deletions
diff --git a/sqlglot/dataframe/README.md b/sqlglot/dataframe/README.md
index 02179f4..86fdc4b 100644
--- a/sqlglot/dataframe/README.md
+++ b/sqlglot/dataframe/README.md
@@ -9,7 +9,7 @@ Currently many of the common operations are covered and more functionality will
 ## Instructions
 * [Install SQLGlot](https://github.com/tobymao/sqlglot/blob/main/README.md#install) and that is all that is required to just generate SQL. [The examples](#examples) show generating SQL and then executing that SQL on a specific engine and that will require that engine's client library.
 * Find/replace all `from pyspark.sql` with `from sqlglot.dataframe`.
-* Prior to any `spark.read.table` or `spark.table` run `sqlglot.schema.add_table('<table_name>', <column_structure>)`.
+* Prior to any `spark.read.table` or `spark.table` run `sqlglot.schema.add_table('<table_name>', <column_structure>, dialect="spark")`.
 * The column structure can be defined the following ways:
   * Dictionary where the keys are column names and values are string of the Spark SQL type name.
     * Ex: `{'cola': 'string', 'colb': 'int'}`
@@ -33,12 +33,16 @@ import sqlglot
 from sqlglot.dataframe.sql.session import SparkSession
 from sqlglot.dataframe.sql import functions as F
 
-sqlglot.schema.add_table('employee', {
-    'employee_id': 'INT',
-    'fname': 'STRING',
-    'lname': 'STRING',
-    'age': 'INT',
-}) # Register the table structure prior to reading from the table
+sqlglot.schema.add_table(
+    'employee',
+    {
+        'employee_id': 'INT',
+        'fname': 'STRING',
+        'lname': 'STRING',
+        'age': 'INT',
+    },
+    dialect="spark",
+) # Register the table structure prior to reading from the table
 
 spark = SparkSession()
 
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
index a8b89d1..f4cfeba 100644
--- a/sqlglot/dataframe/sql/column.py
+++ b/sqlglot/dataframe/sql/column.py
@@ -5,6 +5,7 @@ import typing as t
 import sqlglot
 from sqlglot import expressions as exp
 from sqlglot.dataframe.sql.types import DataType
+from sqlglot.dialects import Spark
 from sqlglot.helper import flatten, is_iterable
 
 if t.TYPE_CHECKING:
@@ -22,6 +23,10 @@ class Column:
             expression = sqlglot.maybe_parse(expression, dialect="spark")
         if expression is None:
             raise ValueError(f"Could not parse {expression}")
+
+        if isinstance(expression, exp.Column):
+            expression.transform(Spark.normalize_identifier, copy=False)
+
         self.expression: exp.Expression = expression
 
     def __repr__(self):
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 3fc9232..64cceea 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -316,6 +316,7 @@ class DataFrame:
                         expression.alias_or_name: expression.type.sql("spark")
                         for expression in select_expression.expressions
                     },
+                    dialect="spark",
                 )
                 cache_storage_level = select_expression.args["cache_storage_level"]
                 options = [
diff --git a/sqlglot/dataframe/sql/normalize.py b/sqlglot/dataframe/sql/normalize.py
index 75feba7..4eec782 100644
--- a/sqlglot/dataframe/sql/normalize.py
+++ b/sqlglot/dataframe/sql/normalize.py
@@ -5,6 +5,7 @@ import typing as t
 from sqlglot import expressions as exp
 from sqlglot.dataframe.sql.column import Column
 from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
+from sqlglot.dialects import Spark
 from sqlglot.helper import ensure_list
 
 NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])
@@ -19,6 +20,7 @@ def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[
     for expression in expressions:
         identifiers = expression.find_all(exp.Identifier)
         for identifier in identifiers:
+            Spark.normalize_identifier(identifier)
             replace_alias_name_with_cte_name(spark, expression_context, identifier)
             replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)
diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py
index cc2f181..9d87d4a 100644
--- a/sqlglot/dataframe/sql/readwriter.py
+++ b/sqlglot/dataframe/sql/readwriter.py
@@ -4,7 +4,8 @@ import typing as t
 import sqlglot
 from sqlglot import expressions as exp
-from sqlglot.helper import object_to_dict, should_identify
+from sqlglot.dialects import Spark
+from sqlglot.helper import object_to_dict
 
 if t.TYPE_CHECKING:
     from sqlglot.dataframe.sql.dataframe import DataFrame
@@ -18,17 +19,14 @@ class DataFrameReader:
     def table(self, tableName: str) -> DataFrame:
         from sqlglot.dataframe.sql.dataframe import DataFrame
 
-        sqlglot.schema.add_table(tableName)
+        sqlglot.schema.add_table(tableName, dialect="spark")
 
         return DataFrame(
             self.spark,
             exp.Select()
-            .from_(tableName)
+            .from_(exp.to_table(tableName, dialect="spark").transform(Spark.normalize_identifier))
             .select(
-                *(
-                    column if should_identify(column, "safe") else f'"{column}"'
-                    for column in sqlglot.schema.column_names(tableName)
-                )
+                *(column for column in sqlglot.schema.column_names(tableName, dialect="spark"))
             ),
         )
@@ -73,7 +71,7 @@ class DataFrameWriter:
         )
         df = self._df.copy(output_expression_container=output_expression_container)
         if self._by_name:
-            columns = sqlglot.schema.column_names(tableName, only_visible=True)
+            columns = sqlglot.schema.column_names(tableName, only_visible=True, dialect="spark")
             df = df._convert_leaf_to_cte().select(*columns)
         return self.copy(_df=df)
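
Taken together, the change threads a dialect="spark" argument through every schema call so table and column names are registered and looked up under Spark's identifier rules. A minimal sketch of the resulting flow, using only the API shown in the README hunk above:

import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

# Register the table with the Spark dialect so the schema layer parses and
# normalizes the table name with Spark's case-insensitive identifier rules.
sqlglot.schema.add_table(
    'employee',
    {'employee_id': 'INT', 'fname': 'STRING'},
    dialect="spark",
)

spark = SparkSession()
df = spark.table('employee')  # reads column names back with dialect="spark"
print(df.sql(pretty=True))    # the generated Spark SQL statement(s)

Passing the dialect at each boundary, rather than baking Spark behavior into the shared sqlglot.schema object, keeps that schema usable from other dialects at the same time.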
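The Spark.normalize_identifier hook imported in column.py, normalize.py, and readwriter.py is what makes those lookups line up. A hedged sketch of what that step does to a single identifier, assuming in-place mutation as the diff's bare call Spark.normalize_identifier(identifier) suggests (the name EmployeeId is illustrative):

from sqlglot import expressions as exp
from sqlglot.dialects import Spark

ident = exp.to_identifier("EmployeeId")  # unquoted, so Spark treats it case-insensitively
Spark.normalize_identifier(ident)        # normalizes the name in place
print(ident.name)                        # expected: employeeid (lowercase for Spark)

Quoted identifiers are the case-sensitive escape hatch, so normalization like this applies only to unquoted names.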