author     Daniel Baumann <daniel.baumann@progress-linux.org>  2023-09-07 11:39:48 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2023-09-07 11:39:48 +0000
commit     f73e9af131151f1e058446361c35b05c4c90bf10 (patch)
tree       ed425b89f12d3f5e4709290bdc03d876f365bc97 /sqlglot/dataframe/sql/dataframe.py
parent     Releasing debian version 17.12.0-1. (diff)
Merging upstream version 18.2.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dataframe/sql/dataframe.py')
-rw-r--r--  sqlglot/dataframe/sql/dataframe.py  34
1 file changed, 27 insertions(+), 7 deletions(-)
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 64cceea..f515608 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -1,12 +1,13 @@
from __future__ import annotations
import functools
+import logging
import typing as t
import zlib
from copy import copy
import sqlglot
-from sqlglot import expressions as exp
+from sqlglot import Dialect, expressions as exp
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.group import GroupedData
@@ -18,6 +19,7 @@ from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
from sqlglot.dataframe.sql.window import Window
from sqlglot.helper import ensure_list, object_to_dict, seq_get
from sqlglot.optimizer import optimize as optimize_func
+from sqlglot.optimizer.qualify_columns import quote_identifiers
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql._typing import (
@@ -27,7 +29,9 @@ if t.TYPE_CHECKING:
OutputExpressionContainer,
)
from sqlglot.dataframe.sql.session import SparkSession
+ from sqlglot.dialects.dialect import DialectType
+logger = logging.getLogger("sqlglot")
JOIN_HINTS = {
"BROADCAST",
@@ -264,7 +268,9 @@ class DataFrame:
@classmethod
def _create_hash_from_expression(cls, expression: exp.Expression) -> str:
- value = expression.sql(dialect="spark").encode("utf-8")
+ from sqlglot.dataframe.sql.session import SparkSession
+
+ value = expression.sql(dialect=SparkSession().dialect).encode("utf-8")
return f"t{zlib.crc32(value)}"[:6]
def _get_select_expressions(
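
`_create_hash_from_expression` now renders the expression with the session's dialect instead of the hard-coded "spark" string, but the naming scheme is unchanged: CRC32 the rendered SQL bytes, prefix with `t`, and truncate to six characters so identical queries map to identical table names. A standalone sketch of the same scheme, using a hand-written query as the assumed input:

```python
# Sketch of the naming scheme above: "t" + CRC32 of the rendered SQL,
# truncated to at most 6 characters, so equal queries get equal names.
import zlib

import sqlglot

expression = sqlglot.parse_one("SELECT a FROM x", dialect="spark")
value = expression.sql(dialect="spark").encode("utf-8")
print(f"t{zlib.crc32(value)}"[:6])
```
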
@@ -291,7 +297,15 @@ class DataFrame:
select_expressions.append(expression_select_pair) # type: ignore
return select_expressions
- def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]:
+ def sql(
+ self, dialect: t.Optional[DialectType] = None, optimize: bool = True, **kwargs
+ ) -> t.List[str]:
+ from sqlglot.dataframe.sql.session import SparkSession
+
+ if dialect and Dialect.get_or_raise(dialect)() != SparkSession().dialect:
+ logger.warning(
+ f"The recommended way of defining a dialect is by doing `SparkSession.builder.config('sqlframe.dialect', '{dialect}').getOrCreate()`. It is no longer needed then when calling `sql`. If you run into issues try updating your query to use this pattern."
+ )
df = self._resolve_pending_hints()
select_expressions = df._get_select_expressions()
output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
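
With this change, `sql` accepts an optional `DialectType` but warns whenever the argument disagrees with the session dialect: the session is now the single source of truth. A hedged usage sketch of the pattern the warning recommends (the sample data is illustrative):

```python
# Configure the dialect once on the session, as the warning recommends,
# instead of passing dialect= to every sql() call.
from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession.builder.config("sqlframe.dialect", "spark").getOrCreate()
df = spark.createDataFrame([{"a": 1}])
print(df.sql())  # the dialect now comes from the session, not the call site
```
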
@@ -299,7 +313,10 @@ class DataFrame:
for expression_type, select_expression in select_expressions:
select_expression = select_expression.transform(replace_id_value, replacement_mapping)
if optimize:
- select_expression = t.cast(exp.Select, optimize_func(select_expression))
+ quote_identifiers(select_expression)
+ select_expression = t.cast(
+ exp.Select, optimize_func(select_expression, dialect=SparkSession().dialect)
+ )
select_expression = df._replace_cte_names_with_hashes(select_expression)
expression: t.Union[exp.Select, exp.Cache, exp.Drop]
if expression_type == exp.Cache:
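
When `optimize=True`, identifiers are now quoted in place before the optimizer runs, and the optimizer receives the session dialect instead of defaulting. The same two-step flow on a standalone expression, with an assumed toy schema:

```python
# Sketch of the optimize path: quote identifiers in place, then run the
# optimizer under an explicit dialect (the toy schema is an assumption).
import sqlglot
from sqlglot.optimizer import optimize
from sqlglot.optimizer.qualify_columns import quote_identifiers

select_expression = sqlglot.parse_one("SELECT a FROM x", dialect="spark")
quote_identifiers(select_expression, dialect="spark")
optimized = optimize(select_expression, schema={"x": {"a": "int"}}, dialect="spark")
print(optimized.sql(dialect="spark"))
```
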
@@ -313,10 +330,12 @@ class DataFrame:
sqlglot.schema.add_table(
cache_table_name,
{
- expression.alias_or_name: expression.type.sql("spark")
+ expression.alias_or_name: expression.type.sql(
+ dialect=SparkSession().dialect
+ )
for expression in select_expression.expressions
},
- dialect="spark",
+ dialect=SparkSession().dialect,
)
cache_storage_level = select_expression.args["cache_storage_level"]
options = [
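
Caching now registers the cached table's column types in sqlglot's global schema under the session dialect rather than the hard-coded "spark". A sketch of that registration with a hypothetical table name and columns:

```python
# Sketch of the schema registration above; the table and column names
# are hypothetical stand-ins for the generated cache table.
import sqlglot

sqlglot.schema.add_table(
    "cache_table_1",
    {"a": "INT", "b": "STRING"},  # column -> type mapping
    dialect="spark",
)
print(sqlglot.schema.column_names("cache_table_1"))
```
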
@@ -345,7 +364,8 @@ class DataFrame:
output_expressions.append(expression)
return [
- expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions
+ expression.sql(**{"dialect": SparkSession().dialect, **kwargs})
+ for expression in output_expressions
]
def copy(self, **kwargs) -> DataFrame: