diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-09-07 11:39:48 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-09-07 11:39:48 +0000 |
commit | f73e9af131151f1e058446361c35b05c4c90bf10 (patch) | |
tree | ed425b89f12d3f5e4709290bdc03d876f365bc97 /sqlglot/dataframe/sql | |
parent | Releasing debian version 17.12.0-1. (diff) | |
download | sqlglot-f73e9af131151f1e058446361c35b05c4c90bf10.tar.xz sqlglot-f73e9af131151f1e058446361c35b05c4c90bf10.zip |
Merging upstream version 18.2.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dataframe/sql')
-rw-r--r-- | sqlglot/dataframe/sql/column.py | 22 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/dataframe.py | 34 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/functions.py | 8 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/normalize.py | 3 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/readwriter.py | 23 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/session.py | 65 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/window.py | 4 |
7 files changed, 116 insertions, 43 deletions
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py index fcfd71e..3acf494 100644 --- a/sqlglot/dataframe/sql/column.py +++ b/sqlglot/dataframe/sql/column.py @@ -5,7 +5,6 @@ import typing as t import sqlglot from sqlglot import expressions as exp from sqlglot.dataframe.sql.types import DataType -from sqlglot.dialects import Spark from sqlglot.helper import flatten, is_iterable if t.TYPE_CHECKING: @@ -15,19 +14,20 @@ if t.TYPE_CHECKING: class Column: def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]): + from sqlglot.dataframe.sql.session import SparkSession + if isinstance(expression, Column): expression = expression.expression # type: ignore elif expression is None or not isinstance(expression, (str, exp.Expression)): expression = self._lit(expression).expression # type: ignore - - expression = sqlglot.maybe_parse(expression, dialect="spark") + elif not isinstance(expression, exp.Column): + expression = sqlglot.maybe_parse(expression, dialect=SparkSession().dialect).transform( + SparkSession().dialect.normalize_identifier, copy=False + ) if expression is None: raise ValueError(f"Could not parse {expression}") - if isinstance(expression, exp.Column): - expression.transform(Spark.normalize_identifier, copy=False) - - self.expression: exp.Expression = expression + self.expression: exp.Expression = expression # type: ignore def __repr__(self): return repr(self.expression) @@ -207,7 +207,9 @@ class Column: return Column(expression) def sql(self, **kwargs) -> str: - return self.expression.sql(**{"dialect": "spark", **kwargs}) + from sqlglot.dataframe.sql.session import SparkSession + + return self.expression.sql(**{"dialect": SparkSession().dialect, **kwargs}) def alias(self, name: str) -> Column: new_expression = exp.alias_(self.column_expression, name) @@ -264,9 +266,11 @@ class Column: Functionality Difference: PySpark cast accepts a datatype instance of the datatype class Sqlglot doesn't currently replicate this class so it only accepts a string """ + from sqlglot.dataframe.sql.session import SparkSession + if isinstance(dataType, DataType): dataType = dataType.simpleString() - return Column(exp.cast(self.column_expression, dataType, dialect="spark")) + return Column(exp.cast(self.column_expression, dataType, dialect=SparkSession().dialect)) def startswith(self, value: t.Union[str, Column]) -> Column: value = self._lit(value) if not isinstance(value, Column) else value diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index 64cceea..f515608 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -1,12 +1,13 @@ from __future__ import annotations import functools +import logging import typing as t import zlib from copy import copy import sqlglot -from sqlglot import expressions as exp +from sqlglot import Dialect, expressions as exp from sqlglot.dataframe.sql import functions as F from sqlglot.dataframe.sql.column import Column from sqlglot.dataframe.sql.group import GroupedData @@ -18,6 +19,7 @@ from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join from sqlglot.dataframe.sql.window import Window from sqlglot.helper import ensure_list, object_to_dict, seq_get from sqlglot.optimizer import optimize as optimize_func +from sqlglot.optimizer.qualify_columns import quote_identifiers if t.TYPE_CHECKING: from sqlglot.dataframe.sql._typing import ( @@ -27,7 +29,9 @@ if t.TYPE_CHECKING: OutputExpressionContainer, ) from sqlglot.dataframe.sql.session import SparkSession + from sqlglot.dialects.dialect import DialectType +logger = logging.getLogger("sqlglot") JOIN_HINTS = { "BROADCAST", @@ -264,7 +268,9 @@ class DataFrame: @classmethod def _create_hash_from_expression(cls, expression: exp.Expression) -> str: - value = expression.sql(dialect="spark").encode("utf-8") + from sqlglot.dataframe.sql.session import SparkSession + + value = expression.sql(dialect=SparkSession().dialect).encode("utf-8") return f"t{zlib.crc32(value)}"[:6] def _get_select_expressions( @@ -291,7 +297,15 @@ class DataFrame: select_expressions.append(expression_select_pair) # type: ignore return select_expressions - def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]: + def sql( + self, dialect: t.Optional[DialectType] = None, optimize: bool = True, **kwargs + ) -> t.List[str]: + from sqlglot.dataframe.sql.session import SparkSession + + if dialect and Dialect.get_or_raise(dialect)() != SparkSession().dialect: + logger.warning( + f"The recommended way of defining a dialect is by doing `SparkSession.builder.config('sqlframe.dialect', '{dialect}').getOrCreate()`. It is no longer needed then when calling `sql`. If you run into issues try updating your query to use this pattern." + ) df = self._resolve_pending_hints() select_expressions = df._get_select_expressions() output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = [] @@ -299,7 +313,10 @@ class DataFrame: for expression_type, select_expression in select_expressions: select_expression = select_expression.transform(replace_id_value, replacement_mapping) if optimize: - select_expression = t.cast(exp.Select, optimize_func(select_expression)) + quote_identifiers(select_expression) + select_expression = t.cast( + exp.Select, optimize_func(select_expression, dialect=SparkSession().dialect) + ) select_expression = df._replace_cte_names_with_hashes(select_expression) expression: t.Union[exp.Select, exp.Cache, exp.Drop] if expression_type == exp.Cache: @@ -313,10 +330,12 @@ class DataFrame: sqlglot.schema.add_table( cache_table_name, { - expression.alias_or_name: expression.type.sql("spark") + expression.alias_or_name: expression.type.sql( + dialect=SparkSession().dialect + ) for expression in select_expression.expressions }, - dialect="spark", + dialect=SparkSession().dialect, ) cache_storage_level = select_expression.args["cache_storage_level"] options = [ @@ -345,7 +364,8 @@ class DataFrame: output_expressions.append(expression) return [ - expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions + expression.sql(**{"dialect": SparkSession().dialect, **kwargs}) + for expression in output_expressions ] def copy(self, **kwargs) -> DataFrame: diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index 4002cfe..d0ae50c 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -368,9 +368,7 @@ def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column: def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column: - if ignorenulls is not None: - return Column.invoke_anonymous_function(col, "FIRST", ignorenulls) - return Column.invoke_anonymous_function(col, "FIRST") + return Column.invoke_expression_over_column(col, expression.First, ignore_nulls=ignorenulls) def grouping_id(*cols: ColumnOrName) -> Column: @@ -394,9 +392,7 @@ def isnull(col: ColumnOrName) -> Column: def last(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column: - if ignorenulls is not None: - return Column.invoke_anonymous_function(col, "LAST", ignorenulls) - return Column.invoke_anonymous_function(col, "LAST") + return Column.invoke_expression_over_column(col, expression.Last, ignore_nulls=ignorenulls) def monotonically_increasing_id() -> Column: diff --git a/sqlglot/dataframe/sql/normalize.py b/sqlglot/dataframe/sql/normalize.py index 4eec782..f68bacb 100644 --- a/sqlglot/dataframe/sql/normalize.py +++ b/sqlglot/dataframe/sql/normalize.py @@ -5,7 +5,6 @@ import typing as t from sqlglot import expressions as exp from sqlglot.dataframe.sql.column import Column from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join -from sqlglot.dialects import Spark from sqlglot.helper import ensure_list NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column]) @@ -20,7 +19,7 @@ def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[ for expression in expressions: identifiers = expression.find_all(exp.Identifier) for identifier in identifiers: - Spark.normalize_identifier(identifier) + identifier.transform(spark.dialect.normalize_identifier) replace_alias_name_with_cte_name(spark, expression_context, identifier) replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier) diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py index 9d87d4a..0804486 100644 --- a/sqlglot/dataframe/sql/readwriter.py +++ b/sqlglot/dataframe/sql/readwriter.py @@ -4,7 +4,6 @@ import typing as t import sqlglot from sqlglot import expressions as exp -from sqlglot.dialects import Spark from sqlglot.helper import object_to_dict if t.TYPE_CHECKING: @@ -18,15 +17,25 @@ class DataFrameReader: def table(self, tableName: str) -> DataFrame: from sqlglot.dataframe.sql.dataframe import DataFrame + from sqlglot.dataframe.sql.session import SparkSession - sqlglot.schema.add_table(tableName, dialect="spark") + sqlglot.schema.add_table(tableName, dialect=SparkSession().dialect) return DataFrame( self.spark, exp.Select() - .from_(exp.to_table(tableName, dialect="spark").transform(Spark.normalize_identifier)) + .from_( + exp.to_table(tableName, dialect=SparkSession().dialect).transform( + SparkSession().dialect.normalize_identifier + ) + ) .select( - *(column for column in sqlglot.schema.column_names(tableName, dialect="spark")) + *( + column + for column in sqlglot.schema.column_names( + tableName, dialect=SparkSession().dialect + ) + ) ), ) @@ -63,6 +72,8 @@ class DataFrameWriter: return self.copy(by_name=True) def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter: + from sqlglot.dataframe.sql.session import SparkSession + output_expression_container = exp.Insert( **{ "this": exp.to_table(tableName), @@ -71,7 +82,9 @@ class DataFrameWriter: ) df = self._df.copy(output_expression_container=output_expression_container) if self._by_name: - columns = sqlglot.schema.column_names(tableName, only_visible=True, dialect="spark") + columns = sqlglot.schema.column_names( + tableName, only_visible=True, dialect=SparkSession().dialect + ) df = df._convert_leaf_to_cte().select(*columns) return self.copy(_df=df) diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py index b883359..531ee17 100644 --- a/sqlglot/dataframe/sql/session.py +++ b/sqlglot/dataframe/sql/session.py @@ -5,31 +5,35 @@ import uuid from collections import defaultdict import sqlglot -from sqlglot import expressions as exp +from sqlglot import Dialect, expressions as exp from sqlglot.dataframe.sql import functions as F from sqlglot.dataframe.sql.dataframe import DataFrame from sqlglot.dataframe.sql.readwriter import DataFrameReader from sqlglot.dataframe.sql.types import StructType from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input +from sqlglot.helper import classproperty if t.TYPE_CHECKING: from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput class SparkSession: - known_ids: t.ClassVar[t.Set[str]] = set() - known_branch_ids: t.ClassVar[t.Set[str]] = set() - known_sequence_ids: t.ClassVar[t.Set[str]] = set() - name_to_sequence_id_mapping: t.ClassVar[t.Dict[str, t.List[str]]] = defaultdict(list) + DEFAULT_DIALECT = "spark" + _instance = None def __init__(self): - self.incrementing_id = 1 - - def __getattr__(self, name: str) -> SparkSession: - return self - - def __call__(self, *args, **kwargs) -> SparkSession: - return self + if not hasattr(self, "known_ids"): + self.known_ids = set() + self.known_branch_ids = set() + self.known_sequence_ids = set() + self.name_to_sequence_id_mapping = defaultdict(list) + self.incrementing_id = 1 + self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)() + + def __new__(cls, *args, **kwargs) -> SparkSession: + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance @property def read(self) -> DataFrameReader: @@ -101,7 +105,7 @@ class SparkSession: return DataFrame(self, sel_expression) def sql(self, sqlQuery: str) -> DataFrame: - expression = sqlglot.parse_one(sqlQuery, read="spark") + expression = sqlglot.parse_one(sqlQuery, read=self.dialect) if isinstance(expression, exp.Select): df = DataFrame(self, expression) df = df._convert_leaf_to_cte() @@ -149,3 +153,38 @@ class SparkSession: def _add_alias_to_mapping(self, name: str, sequence_id: str): self.name_to_sequence_id_mapping[name].append(sequence_id) + + class Builder: + SQLFRAME_DIALECT_KEY = "sqlframe.dialect" + + def __init__(self): + self.dialect = "spark" + + def __getattr__(self, item) -> SparkSession.Builder: + return self + + def __call__(self, *args, **kwargs): + return self + + def config( + self, + key: t.Optional[str] = None, + value: t.Optional[t.Any] = None, + *, + map: t.Optional[t.Dict[str, t.Any]] = None, + **kwargs: t.Any, + ) -> SparkSession.Builder: + if key == self.SQLFRAME_DIALECT_KEY: + self.dialect = value + elif map and self.SQLFRAME_DIALECT_KEY in map: + self.dialect = map[self.SQLFRAME_DIALECT_KEY] + return self + + def getOrCreate(self) -> SparkSession: + spark = SparkSession() + spark.dialect = Dialect.get_or_raise(self.dialect)() + return spark + + @classproperty + def builder(cls) -> Builder: + return cls.Builder() diff --git a/sqlglot/dataframe/sql/window.py b/sqlglot/dataframe/sql/window.py index c54c07e..c1d913f 100644 --- a/sqlglot/dataframe/sql/window.py +++ b/sqlglot/dataframe/sql/window.py @@ -48,7 +48,9 @@ class WindowSpec: return WindowSpec(self.expression.copy()) def sql(self, **kwargs) -> str: - return self.expression.sql(dialect="spark", **kwargs) + from sqlglot.dataframe.sql.session import SparkSession + + return self.expression.sql(dialect=SparkSession().dialect, **kwargs) def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: from sqlglot.dataframe.sql.column import Column |