summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/dataframe/sql
diff options
context:
space:
mode:
author: Daniel Baumann <mail@daniel-baumann.ch>, 2023-12-10 10:46:01 +0000
committer: Daniel Baumann <mail@daniel-baumann.ch>, 2023-12-10 10:46:01 +0000
commit: 8fe30fd23dc37ec3516e530a86d1c4b604e71241 (patch)
tree: 6e2ebbf565b0351fd0f003f488a8339e771ad90c /sqlglot/dataframe/sql
parent: Releasing debian version 19.0.1-1. (diff)
download: sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.tar.xz
download: sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.zip
Merging upstream version 20.1.0.
Signed-off-by: Daniel Baumann <mail@daniel-baumann.ch>
Diffstat (limited to 'sqlglot/dataframe/sql')
-rw-r--r-- sqlglot/dataframe/sql/dataframe.py | 39
-rw-r--r-- sqlglot/dataframe/sql/functions.py | 2
-rw-r--r-- sqlglot/dataframe/sql/session.py | 4
3 files changed, 19 insertions, 26 deletions
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index f515608..68d36fe 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -297,27 +297,26 @@ class DataFrame:
select_expressions.append(expression_select_pair) # type: ignore
return select_expressions
- def sql(
- self, dialect: t.Optional[DialectType] = None, optimize: bool = True, **kwargs
- ) -> t.List[str]:
+ def sql(self, dialect: DialectType = None, optimize: bool = True, **kwargs) -> t.List[str]:
from sqlglot.dataframe.sql.session import SparkSession
- if dialect and Dialect.get_or_raise(dialect)() != SparkSession().dialect:
- logger.warning(
- f"The recommended way of defining a dialect is by doing `SparkSession.builder.config('sqlframe.dialect', '{dialect}').getOrCreate()`. It is no longer needed then when calling `sql`. If you run into issues try updating your query to use this pattern."
- )
+ dialect = Dialect.get_or_raise(dialect or SparkSession().dialect)
+
df = self._resolve_pending_hints()
select_expressions = df._get_select_expressions()
output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
+
for expression_type, select_expression in select_expressions:
select_expression = select_expression.transform(replace_id_value, replacement_mapping)
if optimize:
- quote_identifiers(select_expression)
+ quote_identifiers(select_expression, dialect=dialect)
select_expression = t.cast(
- exp.Select, optimize_func(select_expression, dialect=SparkSession().dialect)
+ exp.Select, optimize_func(select_expression, dialect=dialect)
)
+
select_expression = df._replace_cte_names_with_hashes(select_expression)
+
expression: t.Union[exp.Select, exp.Cache, exp.Drop]
if expression_type == exp.Cache:
cache_table_name = df._create_hash_from_expression(select_expression)
@@ -330,13 +329,12 @@ class DataFrame:
sqlglot.schema.add_table(
cache_table_name,
{
- expression.alias_or_name: expression.type.sql(
- dialect=SparkSession().dialect
- )
+ expression.alias_or_name: expression.type.sql(dialect=dialect)
for expression in select_expression.expressions
},
- dialect=SparkSession().dialect,
+ dialect=dialect,
)
+
cache_storage_level = select_expression.args["cache_storage_level"]
options = [
exp.Literal.string("storageLevel"),
@@ -345,6 +343,7 @@ class DataFrame:
expression = exp.Cache(
this=cache_table, expression=select_expression, lazy=True, options=options
)
+
# We will drop the "view" if it exists before running the cache table
output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
elif expression_type == exp.Create:
@@ -355,18 +354,17 @@ class DataFrame:
select_without_ctes = select_expression.copy()
select_without_ctes.set("with", None)
expression.set("expression", select_without_ctes)
+
if select_expression.ctes:
expression.set("with", exp.With(expressions=select_expression.ctes))
elif expression_type == exp.Select:
expression = select_expression
else:
raise ValueError(f"Invalid expression type: {expression_type}")
+
output_expressions.append(expression)
- return [
- expression.sql(**{"dialect": SparkSession().dialect, **kwargs})
- for expression in output_expressions
- ]
+ return [expression.sql(dialect=dialect, **kwargs) for expression in output_expressions]
def copy(self, **kwargs) -> DataFrame:
return DataFrame(**object_to_dict(self, **kwargs))
@@ -542,12 +540,7 @@ class DataFrame:
"""
columns = self._ensure_and_normalize_cols(cols)
pre_ordered_col_indexes = [
- x
- for x in [
- i if isinstance(col.expression, exp.Ordered) else None
- for i, col in enumerate(columns)
- ]
- if x is not None
+ i for i, col in enumerate(columns) if isinstance(col.expression, exp.Ordered)
]
if ascending is None:
ascending = [True] * len(columns)
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index a424ea4..6671c5b 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -306,7 +306,7 @@ def collect_list(col: ColumnOrName) -> Column:
def collect_set(col: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(col, expression.SetAgg)
+ return Column.invoke_expression_over_column(col, expression.ArrayUniqueAgg)
def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column:
diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py
index 531ee17..4a33ef9 100644
--- a/sqlglot/dataframe/sql/session.py
+++ b/sqlglot/dataframe/sql/session.py
@@ -28,7 +28,7 @@ class SparkSession:
self.known_sequence_ids = set()
self.name_to_sequence_id_mapping = defaultdict(list)
self.incrementing_id = 1
- self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)()
+ self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)
def __new__(cls, *args, **kwargs) -> SparkSession:
if cls._instance is None:
@@ -182,7 +182,7 @@ class SparkSession:
def getOrCreate(self) -> SparkSession:
spark = SparkSession()
- spark.dialect = Dialect.get_or_raise(self.dialect)()
+ spark.dialect = Dialect.get_or_raise(self.dialect)
return spark
@classproperty