diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-09-07 11:39:48 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-09-07 11:39:48 +0000 |
commit | f73e9af131151f1e058446361c35b05c4c90bf10 (patch) | |
tree | ed425b89f12d3f5e4709290bdc03d876f365bc97 /sqlglot/dataframe/sql/dataframe.py | |
parent | Releasing debian version 17.12.0-1. (diff) | |
download | sqlglot-f73e9af131151f1e058446361c35b05c4c90bf10.tar.xz sqlglot-f73e9af131151f1e058446361c35b05c4c90bf10.zip |
Merging upstream version 18.2.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dataframe/sql/dataframe.py')
-rw-r--r-- | sqlglot/dataframe/sql/dataframe.py | 34 |
1 files changed, 27 insertions, 7 deletions
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index 64cceea..f515608 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -1,12 +1,13 @@ from __future__ import annotations import functools +import logging import typing as t import zlib from copy import copy import sqlglot -from sqlglot import expressions as exp +from sqlglot import Dialect, expressions as exp from sqlglot.dataframe.sql import functions as F from sqlglot.dataframe.sql.column import Column from sqlglot.dataframe.sql.group import GroupedData @@ -18,6 +19,7 @@ from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join from sqlglot.dataframe.sql.window import Window from sqlglot.helper import ensure_list, object_to_dict, seq_get from sqlglot.optimizer import optimize as optimize_func +from sqlglot.optimizer.qualify_columns import quote_identifiers if t.TYPE_CHECKING: from sqlglot.dataframe.sql._typing import ( @@ -27,7 +29,9 @@ if t.TYPE_CHECKING: OutputExpressionContainer, ) from sqlglot.dataframe.sql.session import SparkSession + from sqlglot.dialects.dialect import DialectType +logger = logging.getLogger("sqlglot") JOIN_HINTS = { "BROADCAST", @@ -264,7 +268,9 @@ class DataFrame: @classmethod def _create_hash_from_expression(cls, expression: exp.Expression) -> str: - value = expression.sql(dialect="spark").encode("utf-8") + from sqlglot.dataframe.sql.session import SparkSession + + value = expression.sql(dialect=SparkSession().dialect).encode("utf-8") return f"t{zlib.crc32(value)}"[:6] def _get_select_expressions( @@ -291,7 +297,15 @@ class DataFrame: select_expressions.append(expression_select_pair) # type: ignore return select_expressions - def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]: + def sql( + self, dialect: t.Optional[DialectType] = None, optimize: bool = True, **kwargs + ) -> t.List[str]: + from sqlglot.dataframe.sql.session import SparkSession + + if dialect and Dialect.get_or_raise(dialect)() != SparkSession().dialect: + logger.warning( + f"The recommended way of defining a dialect is by doing `SparkSession.builder.config('sqlframe.dialect', '{dialect}').getOrCreate()`. It is no longer needed then when calling `sql`. If you run into issues try updating your query to use this pattern." + ) df = self._resolve_pending_hints() select_expressions = df._get_select_expressions() output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = [] @@ -299,7 +313,10 @@ class DataFrame: for expression_type, select_expression in select_expressions: select_expression = select_expression.transform(replace_id_value, replacement_mapping) if optimize: - select_expression = t.cast(exp.Select, optimize_func(select_expression)) + quote_identifiers(select_expression) + select_expression = t.cast( + exp.Select, optimize_func(select_expression, dialect=SparkSession().dialect) + ) select_expression = df._replace_cte_names_with_hashes(select_expression) expression: t.Union[exp.Select, exp.Cache, exp.Drop] if expression_type == exp.Cache: @@ -313,10 +330,12 @@ class DataFrame: sqlglot.schema.add_table( cache_table_name, { - expression.alias_or_name: expression.type.sql("spark") + expression.alias_or_name: expression.type.sql( + dialect=SparkSession().dialect + ) for expression in select_expression.expressions }, - dialect="spark", + dialect=SparkSession().dialect, ) cache_storage_level = select_expression.args["cache_storage_level"] options = [ @@ -345,7 +364,8 @@ class DataFrame: output_expressions.append(expression) return [ - expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions + expression.sql(**{"dialect": SparkSession().dialect, **kwargs}) + for expression in output_expressions ] def copy(self, **kwargs) -> DataFrame: |