From 8d36f5966675e23bee7026ba37ae0647fbf47300 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Mon, 8 Apr 2024 10:11:53 +0200 Subject: Merging upstream version 23.7.0. Signed-off-by: Daniel Baumann --- sqlglot/dataframe/sql/dataframe.py | 13 +++++++------ sqlglot/dataframe/sql/functions.py | 14 +++++++------- sqlglot/dataframe/sql/session.py | 11 ++++++++++- 3 files changed, 24 insertions(+), 14 deletions(-) (limited to 'sqlglot/dataframe/sql') diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index 0bacbf9..8316c36 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -18,8 +18,6 @@ from sqlglot.dataframe.sql.transforms import replace_id_value from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join from sqlglot.dataframe.sql.window import Window from sqlglot.helper import ensure_list, object_to_dict, seq_get -from sqlglot.optimizer import optimize as optimize_func -from sqlglot.optimizer.qualify_columns import quote_identifiers if t.TYPE_CHECKING: from sqlglot.dataframe.sql._typing import ( @@ -121,7 +119,9 @@ class DataFrame: self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"] ) replacement_mapping[old_name_id] = new_hashed_id - expression = expression.transform(replace_id_value, replacement_mapping) + expression = expression.transform(replace_id_value, replacement_mapping).assert_is( + exp.Select + ) return expression def _create_cte_from_expression( @@ -306,11 +306,12 @@ class DataFrame: replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {} for expression_type, select_expression in select_expressions: - select_expression = select_expression.transform(replace_id_value, replacement_mapping) + select_expression = select_expression.transform( + replace_id_value, replacement_mapping + ).assert_is(exp.Select) if optimize: - quote_identifiers(select_expression, dialect=dialect) select_expression = t.cast( - exp.Select, optimize_func(select_expression, dialect=dialect) + exp.Select, self.spark._optimize(select_expression, dialect=dialect) ) select_expression = df._replace_cte_names_with_hashes(select_expression) diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index db5201f..b4dd2c6 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -184,7 +184,7 @@ def floor(col: ColumnOrName) -> Column: def log10(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Log10) + return Column.invoke_expression_over_column(lit(10), expression.Log, expression=col) def log1p(col: ColumnOrName) -> Column: @@ -192,7 +192,7 @@ def log1p(col: ColumnOrName) -> Column: def log2(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Log2) + return Column.invoke_expression_over_column(lit(2), expression.Log, expression=col) def log(arg1: t.Union[ColumnOrName, float], arg2: t.Optional[ColumnOrName] = None) -> Column: @@ -356,15 +356,15 @@ def coalesce(*cols: ColumnOrName) -> Column: def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col1, "CORR", col2) + return Column.invoke_expression_over_column(col1, expression.Corr, expression=col2) def covar_pop(col1: ColumnOrName, col2: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col1, "COVAR_POP", col2) + return Column.invoke_expression_over_column(col1, expression.CovarPop, expression=col2) def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col1, "COVAR_SAMP", col2) + return Column.invoke_expression_over_column(col1, expression.CovarSamp, expression=col2) def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column: @@ -971,10 +971,10 @@ def array_join( ) -> Column: if null_replacement is not None: return Column.invoke_expression_over_column( - col, expression.ArrayJoin, expression=lit(delimiter), null=lit(null_replacement) + col, expression.ArrayToString, expression=lit(delimiter), null=lit(null_replacement) ) return Column.invoke_expression_over_column( - col, expression.ArrayJoin, expression=lit(delimiter) + col, expression.ArrayToString, expression=lit(delimiter) ) diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py index bfc022b..4e47aaa 100644 --- a/sqlglot/dataframe/sql/session.py +++ b/sqlglot/dataframe/sql/session.py @@ -12,6 +12,8 @@ from sqlglot.dataframe.sql.readwriter import DataFrameReader from sqlglot.dataframe.sql.types import StructType from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input from sqlglot.helper import classproperty +from sqlglot.optimizer import optimize +from sqlglot.optimizer.qualify_columns import quote_identifiers if t.TYPE_CHECKING: from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput @@ -104,8 +106,15 @@ class SparkSession: sel_expression = exp.Select(**select_kwargs) return DataFrame(self, sel_expression) + def _optimize( + self, expression: exp.Expression, dialect: t.Optional[Dialect] = None + ) -> exp.Expression: + dialect = dialect or self.dialect + quote_identifiers(expression, dialect=dialect) + return optimize(expression, dialect=dialect) + def sql(self, sqlQuery: str) -> DataFrame: - expression = sqlglot.parse_one(sqlQuery, read=self.dialect) + expression = self._optimize(sqlglot.parse_one(sqlQuery, read=self.dialect)) if isinstance(expression, exp.Select): df = DataFrame(self, expression) df = df._convert_leaf_to_cte() -- cgit v1.2.3