summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/dataframe/sql
diff options
context:
space:
mode:
author: Daniel Baumann <daniel.baumann@progress-linux.org> 2024-04-08 08:11:53 +0000
committer: Daniel Baumann <daniel.baumann@progress-linux.org> 2024-04-08 08:12:02 +0000
commit: 8d36f5966675e23bee7026ba37ae0647fbf47300 (patch)
tree: df4227bbb3b07cb70df87237bcff03c8efd7822d /sqlglot/dataframe/sql
parent: Releasing debian version 22.2.0-1. (diff)
download: sqlglot-8d36f5966675e23bee7026ba37ae0647fbf47300.tar.xz
download: sqlglot-8d36f5966675e23bee7026ba37ae0647fbf47300.zip
Merging upstream version 23.7.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dataframe/sql')
-rw-r--r-- sqlglot/dataframe/sql/dataframe.py | 13
-rw-r--r-- sqlglot/dataframe/sql/functions.py | 14
-rw-r--r-- sqlglot/dataframe/sql/session.py | 11
3 files changed, 24 insertions, 14 deletions
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 0bacbf9..8316c36 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -18,8 +18,6 @@ from sqlglot.dataframe.sql.transforms import replace_id_value
from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
from sqlglot.dataframe.sql.window import Window
from sqlglot.helper import ensure_list, object_to_dict, seq_get
-from sqlglot.optimizer import optimize as optimize_func
-from sqlglot.optimizer.qualify_columns import quote_identifiers
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql._typing import (
@@ -121,7 +119,9 @@ class DataFrame:
self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
)
replacement_mapping[old_name_id] = new_hashed_id
- expression = expression.transform(replace_id_value, replacement_mapping)
+ expression = expression.transform(replace_id_value, replacement_mapping).assert_is(
+ exp.Select
+ )
return expression
def _create_cte_from_expression(
@@ -306,11 +306,12 @@ class DataFrame:
replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
for expression_type, select_expression in select_expressions:
- select_expression = select_expression.transform(replace_id_value, replacement_mapping)
+ select_expression = select_expression.transform(
+ replace_id_value, replacement_mapping
+ ).assert_is(exp.Select)
if optimize:
- quote_identifiers(select_expression, dialect=dialect)
select_expression = t.cast(
- exp.Select, optimize_func(select_expression, dialect=dialect)
+ exp.Select, self.spark._optimize(select_expression, dialect=dialect)
)
select_expression = df._replace_cte_names_with_hashes(select_expression)
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index db5201f..b4dd2c6 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -184,7 +184,7 @@ def floor(col: ColumnOrName) -> Column:
def log10(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Log10)
+ return Column.invoke_expression_over_column(lit(10), expression.Log, expression=col)
def log1p(col: ColumnOrName) -> Column:
@@ -192,7 +192,7 @@ def log1p(col: ColumnOrName) -> Column:
def log2(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.Log2)
+ return Column.invoke_expression_over_column(lit(2), expression.Log, expression=col)
def log(arg1: t.Union[ColumnOrName, float], arg2: t.Optional[ColumnOrName] = None) -> Column:
@@ -356,15 +356,15 @@ def coalesce(*cols: ColumnOrName) -> Column:
def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col1, "CORR", col2)
+ return Column.invoke_expression_over_column(col1, expression.Corr, expression=col2)
def covar_pop(col1: ColumnOrName, col2: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col1, "COVAR_POP", col2)
+ return Column.invoke_expression_over_column(col1, expression.CovarPop, expression=col2)
def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col1, "COVAR_SAMP", col2)
+ return Column.invoke_expression_over_column(col1, expression.CovarSamp, expression=col2)
def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
@@ -971,10 +971,10 @@ def array_join(
) -> Column:
if null_replacement is not None:
return Column.invoke_expression_over_column(
- col, expression.ArrayJoin, expression=lit(delimiter), null=lit(null_replacement)
+ col, expression.ArrayToString, expression=lit(delimiter), null=lit(null_replacement)
)
return Column.invoke_expression_over_column(
- col, expression.ArrayJoin, expression=lit(delimiter)
+ col, expression.ArrayToString, expression=lit(delimiter)
)
diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py
index bfc022b..4e47aaa 100644
--- a/sqlglot/dataframe/sql/session.py
+++ b/sqlglot/dataframe/sql/session.py
@@ -12,6 +12,8 @@ from sqlglot.dataframe.sql.readwriter import DataFrameReader
from sqlglot.dataframe.sql.types import StructType
from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input
from sqlglot.helper import classproperty
+from sqlglot.optimizer import optimize
+from sqlglot.optimizer.qualify_columns import quote_identifiers
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput
@@ -104,8 +106,15 @@ class SparkSession:
sel_expression = exp.Select(**select_kwargs)
return DataFrame(self, sel_expression)
+ def _optimize(
+ self, expression: exp.Expression, dialect: t.Optional[Dialect] = None
+ ) -> exp.Expression:
+ dialect = dialect or self.dialect
+ quote_identifiers(expression, dialect=dialect)
+ return optimize(expression, dialect=dialect)
+
def sql(self, sqlQuery: str) -> DataFrame:
- expression = sqlglot.parse_one(sqlQuery, read=self.dialect)
+ expression = self._optimize(sqlglot.parse_one(sqlQuery, read=self.dialect))
if isinstance(expression, exp.Select):
df = DataFrame(self, expression)
df = df._convert_leaf_to_cte()