diff options
Diffstat (limited to 'sqlglot/dataframe/sql')
-rw-r--r-- | sqlglot/dataframe/sql/dataframe.py | 39 |
-rw-r--r-- | sqlglot/dataframe/sql/functions.py | 2 |
-rw-r--r-- | sqlglot/dataframe/sql/session.py | 4 |
3 files changed, 19 insertions, 26 deletions
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index f515608..68d36fe 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -297,27 +297,26 @@ class DataFrame: select_expressions.append(expression_select_pair) # type: ignore return select_expressions - def sql( - self, dialect: t.Optional[DialectType] = None, optimize: bool = True, **kwargs - ) -> t.List[str]: + def sql(self, dialect: DialectType = None, optimize: bool = True, **kwargs) -> t.List[str]: from sqlglot.dataframe.sql.session import SparkSession - if dialect and Dialect.get_or_raise(dialect)() != SparkSession().dialect: - logger.warning( - f"The recommended way of defining a dialect is by doing `SparkSession.builder.config('sqlframe.dialect', '{dialect}').getOrCreate()`. It is no longer needed then when calling `sql`. If you run into issues try updating your query to use this pattern." - ) + dialect = Dialect.get_or_raise(dialect or SparkSession().dialect) + df = self._resolve_pending_hints() select_expressions = df._get_select_expressions() output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = [] replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {} + for expression_type, select_expression in select_expressions: select_expression = select_expression.transform(replace_id_value, replacement_mapping) if optimize: - quote_identifiers(select_expression) + quote_identifiers(select_expression, dialect=dialect) select_expression = t.cast( - exp.Select, optimize_func(select_expression, dialect=SparkSession().dialect) + exp.Select, optimize_func(select_expression, dialect=dialect) ) + select_expression = df._replace_cte_names_with_hashes(select_expression) + expression: t.Union[exp.Select, exp.Cache, exp.Drop] if expression_type == exp.Cache: cache_table_name = df._create_hash_from_expression(select_expression) @@ -330,13 +329,12 @@ class DataFrame: sqlglot.schema.add_table( cache_table_name, { - 
expression.alias_or_name: expression.type.sql( - dialect=SparkSession().dialect - ) + expression.alias_or_name: expression.type.sql(dialect=dialect) for expression in select_expression.expressions }, - dialect=SparkSession().dialect, + dialect=dialect, ) + cache_storage_level = select_expression.args["cache_storage_level"] options = [ exp.Literal.string("storageLevel"), @@ -345,6 +343,7 @@ class DataFrame: expression = exp.Cache( this=cache_table, expression=select_expression, lazy=True, options=options ) + # We will drop the "view" if it exists before running the cache table output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW")) elif expression_type == exp.Create: @@ -355,18 +354,17 @@ class DataFrame: select_without_ctes = select_expression.copy() select_without_ctes.set("with", None) expression.set("expression", select_without_ctes) + if select_expression.ctes: expression.set("with", exp.With(expressions=select_expression.ctes)) elif expression_type == exp.Select: expression = select_expression else: raise ValueError(f"Invalid expression type: {expression_type}") + output_expressions.append(expression) - return [ - expression.sql(**{"dialect": SparkSession().dialect, **kwargs}) - for expression in output_expressions - ] + return [expression.sql(dialect=dialect, **kwargs) for expression in output_expressions] def copy(self, **kwargs) -> DataFrame: return DataFrame(**object_to_dict(self, **kwargs)) @@ -542,12 +540,7 @@ class DataFrame: """ columns = self._ensure_and_normalize_cols(cols) pre_ordered_col_indexes = [ - x - for x in [ - i if isinstance(col.expression, exp.Ordered) else None - for i, col in enumerate(columns) - ] - if x is not None + i for i, col in enumerate(columns) if isinstance(col.expression, exp.Ordered) ] if ascending is None: ascending = [True] * len(columns) diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index a424ea4..6671c5b 100644 --- a/sqlglot/dataframe/sql/functions.py +++ 
b/sqlglot/dataframe/sql/functions.py @@ -306,7 +306,7 @@ def collect_list(col: ColumnOrName) -> Column: def collect_set(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.SetAgg) + return Column.invoke_expression_over_column(col, expression.ArrayUniqueAgg) def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column: diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py index 531ee17..4a33ef9 100644 --- a/sqlglot/dataframe/sql/session.py +++ b/sqlglot/dataframe/sql/session.py @@ -28,7 +28,7 @@ class SparkSession: self.known_sequence_ids = set() self.name_to_sequence_id_mapping = defaultdict(list) self.incrementing_id = 1 - self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)() + self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT) def __new__(cls, *args, **kwargs) -> SparkSession: if cls._instance is None: @@ -182,7 +182,7 @@ class SparkSession: def getOrCreate(self) -> SparkSession: spark = SparkSession() - spark.dialect = Dialect.get_or_raise(self.dialect)() + spark.dialect = Dialect.get_or_raise(self.dialect) return spark @classproperty |