author     Daniel Baumann <daniel.baumann@progress-linux.org>  2023-09-07 11:39:48 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2023-09-07 11:39:48 +0000
commit     f73e9af131151f1e058446361c35b05c4c90bf10 (patch)
tree       ed425b89f12d3f5e4709290bdc03d876f365bc97 /sqlglot/dataframe/sql
parent     Releasing debian version 17.12.0-1. (diff)
download   sqlglot-f73e9af131151f1e058446361c35b05c4c90bf10.tar.xz
           sqlglot-f73e9af131151f1e058446361c35b05c4c90bf10.zip
Merging upstream version 18.2.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dataframe/sql')
-rw-r--r--  sqlglot/dataframe/sql/column.py      | 22
-rw-r--r--  sqlglot/dataframe/sql/dataframe.py   | 34
-rw-r--r--  sqlglot/dataframe/sql/functions.py   |  8
-rw-r--r--  sqlglot/dataframe/sql/normalize.py   |  3
-rw-r--r--  sqlglot/dataframe/sql/readwriter.py  | 23
-rw-r--r--  sqlglot/dataframe/sql/session.py     | 65
-rw-r--r--  sqlglot/dataframe/sql/window.py      |  4
7 files changed, 116 insertions, 43 deletions
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
index fcfd71e..3acf494 100644
--- a/sqlglot/dataframe/sql/column.py
+++ b/sqlglot/dataframe/sql/column.py
@@ -5,7 +5,6 @@ import typing as t
import sqlglot
from sqlglot import expressions as exp
from sqlglot.dataframe.sql.types import DataType
-from sqlglot.dialects import Spark
from sqlglot.helper import flatten, is_iterable
if t.TYPE_CHECKING:
@@ -15,19 +14,20 @@ if t.TYPE_CHECKING:
class Column:
def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]):
+ from sqlglot.dataframe.sql.session import SparkSession
+
if isinstance(expression, Column):
expression = expression.expression # type: ignore
elif expression is None or not isinstance(expression, (str, exp.Expression)):
expression = self._lit(expression).expression # type: ignore
-
- expression = sqlglot.maybe_parse(expression, dialect="spark")
+ elif not isinstance(expression, exp.Column):
+ expression = sqlglot.maybe_parse(expression, dialect=SparkSession().dialect).transform(
+ SparkSession().dialect.normalize_identifier, copy=False
+ )
if expression is None:
raise ValueError(f"Could not parse {expression}")
- if isinstance(expression, exp.Column):
- expression.transform(Spark.normalize_identifier, copy=False)
-
- self.expression: exp.Expression = expression
+ self.expression: exp.Expression = expression # type: ignore
def __repr__(self):
return repr(self.expression)
@@ -207,7 +207,9 @@ class Column:
return Column(expression)
def sql(self, **kwargs) -> str:
- return self.expression.sql(**{"dialect": "spark", **kwargs})
+ from sqlglot.dataframe.sql.session import SparkSession
+
+ return self.expression.sql(**{"dialect": SparkSession().dialect, **kwargs})
def alias(self, name: str) -> Column:
new_expression = exp.alias_(self.column_expression, name)
@@ -264,9 +266,11 @@ class Column:
Functionality difference: PySpark's cast accepts a datatype instance of the DataType class;
sqlglot doesn't currently replicate this class, so it only accepts a string.
"""
+ from sqlglot.dataframe.sql.session import SparkSession
+
if isinstance(dataType, DataType):
dataType = dataType.simpleString()
- return Column(exp.cast(self.column_expression, dataType, dialect="spark"))
+ return Column(exp.cast(self.column_expression, dataType, dialect=SparkSession().dialect))
def startswith(self, value: t.Union[str, Column]) -> Column:
value = self._lit(value) if not isinstance(value, Column) else value
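The column.py hunks above replace the hard-coded "spark" dialect with the dialect carried by the SparkSession singleton, for parsing, identifier normalization, and SQL generation alike. A minimal sketch of the resulting behavior, using the F.col helper from this package (the rendered output is illustrative and depends on the configured dialect):

    # Sketch: Column parsing and rendering now follow the session dialect.
    from sqlglot.dataframe.sql import functions as F

    c = F.col("Employee_Name")  # parsed and normalized with SparkSession().dialect
    print(c.sql())              # rendered with the same dialect, e.g. employee_name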
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 64cceea..f515608 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -1,12 +1,13 @@
from __future__ import annotations
import functools
+import logging
import typing as t
import zlib
from copy import copy
import sqlglot
-from sqlglot import expressions as exp
+from sqlglot import Dialect, expressions as exp
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.group import GroupedData
@@ -18,6 +19,7 @@ from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
from sqlglot.dataframe.sql.window import Window
from sqlglot.helper import ensure_list, object_to_dict, seq_get
from sqlglot.optimizer import optimize as optimize_func
+from sqlglot.optimizer.qualify_columns import quote_identifiers
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql._typing import (
@@ -27,7 +29,9 @@ if t.TYPE_CHECKING:
OutputExpressionContainer,
)
from sqlglot.dataframe.sql.session import SparkSession
+ from sqlglot.dialects.dialect import DialectType
+logger = logging.getLogger("sqlglot")
JOIN_HINTS = {
"BROADCAST",
@@ -264,7 +268,9 @@ class DataFrame:
@classmethod
def _create_hash_from_expression(cls, expression: exp.Expression) -> str:
- value = expression.sql(dialect="spark").encode("utf-8")
+ from sqlglot.dataframe.sql.session import SparkSession
+
+ value = expression.sql(dialect=SparkSession().dialect).encode("utf-8")
return f"t{zlib.crc32(value)}"[:6]
def _get_select_expressions(
@@ -291,7 +297,15 @@ class DataFrame:
select_expressions.append(expression_select_pair) # type: ignore
return select_expressions
- def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]:
+ def sql(
+ self, dialect: t.Optional[DialectType] = None, optimize: bool = True, **kwargs
+ ) -> t.List[str]:
+ from sqlglot.dataframe.sql.session import SparkSession
+
+ if dialect and Dialect.get_or_raise(dialect)() != SparkSession().dialect:
+ logger.warning(
+ f"The recommended way of defining a dialect is by doing `SparkSession.builder.config('sqlframe.dialect', '{dialect}').getOrCreate()`. It is no longer needed then when calling `sql`. If you run into issues try updating your query to use this pattern."
+ )
df = self._resolve_pending_hints()
select_expressions = df._get_select_expressions()
output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
@@ -299,7 +313,10 @@ class DataFrame:
for expression_type, select_expression in select_expressions:
select_expression = select_expression.transform(replace_id_value, replacement_mapping)
if optimize:
- select_expression = t.cast(exp.Select, optimize_func(select_expression))
+ quote_identifiers(select_expression)
+ select_expression = t.cast(
+ exp.Select, optimize_func(select_expression, dialect=SparkSession().dialect)
+ )
select_expression = df._replace_cte_names_with_hashes(select_expression)
expression: t.Union[exp.Select, exp.Cache, exp.Drop]
if expression_type == exp.Cache:
@@ -313,10 +330,12 @@ class DataFrame:
sqlglot.schema.add_table(
cache_table_name,
{
- expression.alias_or_name: expression.type.sql("spark")
+ expression.alias_or_name: expression.type.sql(
+ dialect=SparkSession().dialect
+ )
for expression in select_expression.expressions
},
- dialect="spark",
+ dialect=SparkSession().dialect,
)
cache_storage_level = select_expression.args["cache_storage_level"]
options = [
@@ -345,7 +364,8 @@ class DataFrame:
output_expressions.append(expression)
return [
- expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions
+ expression.sql(**{"dialect": SparkSession().dialect, **kwargs})
+ for expression in output_expressions
]
def copy(self, **kwargs) -> DataFrame:
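With these dataframe.py changes, DataFrame.sql resolves its dialect from the SparkSession singleton; a dialect passed directly that differs from the session's only triggers the warning above, and rendering still uses the session dialect. A hedged usage sketch:

    # Sketch: DataFrame.sql() renders with the session dialect.
    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession()
    df = spark.sql("SELECT 1 AS id")
    print(df.sql(optimize=False))             # list of SQL strings, session dialect
    df.sql(dialect="duckdb", optimize=False)  # mismatch: logs the warning above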
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index 4002cfe..d0ae50c 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -368,9 +368,7 @@ def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column:
def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
- if ignorenulls is not None:
- return Column.invoke_anonymous_function(col, "FIRST", ignorenulls)
- return Column.invoke_anonymous_function(col, "FIRST")
+ return Column.invoke_expression_over_column(col, expression.First, ignore_nulls=ignorenulls)
def grouping_id(*cols: ColumnOrName) -> Column:
@@ -394,9 +392,7 @@ def isnull(col: ColumnOrName) -> Column:
def last(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
- if ignorenulls is not None:
- return Column.invoke_anonymous_function(col, "LAST", ignorenulls)
- return Column.invoke_anonymous_function(col, "LAST")
+ return Column.invoke_expression_over_column(col, expression.Last, ignore_nulls=ignorenulls)
def monotonically_increasing_id() -> Column:
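The functions.py change swaps the anonymous FIRST/LAST invocations for sqlglot's typed exp.First/exp.Last expressions with an ignore_nulls flag, letting each dialect's generator decide how to spell IGNORE NULLS. A small sketch (output shown for the default Spark dialect; exact rendering is dialect-dependent):

    # Sketch: first()/last() now build typed expressions instead of anonymous calls.
    from sqlglot.dataframe.sql import functions as F

    print(F.first("score", ignorenulls=True).sql())  # e.g. FIRST(score) IGNORE NULLS
    print(F.last("score").sql())                     # e.g. LAST(score)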
diff --git a/sqlglot/dataframe/sql/normalize.py b/sqlglot/dataframe/sql/normalize.py
index 4eec782..f68bacb 100644
--- a/sqlglot/dataframe/sql/normalize.py
+++ b/sqlglot/dataframe/sql/normalize.py
@@ -5,7 +5,6 @@ import typing as t
from sqlglot import expressions as exp
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
-from sqlglot.dialects import Spark
from sqlglot.helper import ensure_list
NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])
@@ -20,7 +19,7 @@ def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[
for expression in expressions:
identifiers = expression.find_all(exp.Identifier)
for identifier in identifiers:
- Spark.normalize_identifier(identifier)
+ identifier.transform(spark.dialect.normalize_identifier)
replace_alias_name_with_cte_name(spark, expression_context, identifier)
replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)
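normalize.py now asks the session's dialect object to normalize identifiers instead of calling the Spark dialect statically. A sketch of the underlying call, assuming sqlglot's public Dialect API as used in this diff:

    # Sketch: dialect-driven identifier normalization, as used in normalize().
    from sqlglot import Dialect, expressions as exp

    dialect = Dialect.get_or_raise("spark")()
    ident = exp.to_identifier("MyTable")
    print(dialect.normalize_identifier(ident))  # Spark is case-insensitive: mytable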
diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py
index 9d87d4a..0804486 100644
--- a/sqlglot/dataframe/sql/readwriter.py
+++ b/sqlglot/dataframe/sql/readwriter.py
@@ -4,7 +4,6 @@ import typing as t
import sqlglot
from sqlglot import expressions as exp
-from sqlglot.dialects import Spark
from sqlglot.helper import object_to_dict
if t.TYPE_CHECKING:
@@ -18,15 +17,25 @@ class DataFrameReader:
def table(self, tableName: str) -> DataFrame:
from sqlglot.dataframe.sql.dataframe import DataFrame
+ from sqlglot.dataframe.sql.session import SparkSession
- sqlglot.schema.add_table(tableName, dialect="spark")
+ sqlglot.schema.add_table(tableName, dialect=SparkSession().dialect)
return DataFrame(
self.spark,
exp.Select()
- .from_(exp.to_table(tableName, dialect="spark").transform(Spark.normalize_identifier))
+ .from_(
+ exp.to_table(tableName, dialect=SparkSession().dialect).transform(
+ SparkSession().dialect.normalize_identifier
+ )
+ )
.select(
- *(column for column in sqlglot.schema.column_names(tableName, dialect="spark"))
+ *(
+ column
+ for column in sqlglot.schema.column_names(
+ tableName, dialect=SparkSession().dialect
+ )
+ )
),
)
@@ -63,6 +72,8 @@ class DataFrameWriter:
return self.copy(by_name=True)
def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter:
+ from sqlglot.dataframe.sql.session import SparkSession
+
output_expression_container = exp.Insert(
**{
"this": exp.to_table(tableName),
@@ -71,7 +82,9 @@ class DataFrameWriter:
)
df = self._df.copy(output_expression_container=output_expression_container)
if self._by_name:
- columns = sqlglot.schema.column_names(tableName, only_visible=True, dialect="spark")
+ columns = sqlglot.schema.column_names(
+ tableName, only_visible=True, dialect=SparkSession().dialect
+ )
df = df._convert_leaf_to_cte().select(*columns)
return self.copy(_df=df)
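In readwriter.py, both the schema registration in table() and the column lookup in insertInto() now pass the session dialect through to the global schema. A sketch of reading a table registered via sqlglot.schema.add_table, as done elsewhere in this diff:

    # Sketch: the reader registers and resolves tables using the session dialect.
    import sqlglot
    from sqlglot.dataframe.sql.session import SparkSession

    sqlglot.schema.add_table("db.employees", {"id": "INT", "name": "STRING"}, dialect="spark")
    df = SparkSession().read.table("db.employees")
    print(df.sql(optimize=False))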
diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py
index b883359..531ee17 100644
--- a/sqlglot/dataframe/sql/session.py
+++ b/sqlglot/dataframe/sql/session.py
@@ -5,31 +5,35 @@ import uuid
from collections import defaultdict
import sqlglot
-from sqlglot import expressions as exp
+from sqlglot import Dialect, expressions as exp
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.dataframe import DataFrame
from sqlglot.dataframe.sql.readwriter import DataFrameReader
from sqlglot.dataframe.sql.types import StructType
from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input
+from sqlglot.helper import classproperty
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput
class SparkSession:
- known_ids: t.ClassVar[t.Set[str]] = set()
- known_branch_ids: t.ClassVar[t.Set[str]] = set()
- known_sequence_ids: t.ClassVar[t.Set[str]] = set()
- name_to_sequence_id_mapping: t.ClassVar[t.Dict[str, t.List[str]]] = defaultdict(list)
+ DEFAULT_DIALECT = "spark"
+ _instance = None
def __init__(self):
- self.incrementing_id = 1
-
- def __getattr__(self, name: str) -> SparkSession:
- return self
-
- def __call__(self, *args, **kwargs) -> SparkSession:
- return self
+ if not hasattr(self, "known_ids"):
+ self.known_ids = set()
+ self.known_branch_ids = set()
+ self.known_sequence_ids = set()
+ self.name_to_sequence_id_mapping = defaultdict(list)
+ self.incrementing_id = 1
+ self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)()
+
+ def __new__(cls, *args, **kwargs) -> SparkSession:
+ if cls._instance is None:
+ cls._instance = super().__new__(cls)
+ return cls._instance
@property
def read(self) -> DataFrameReader:
@@ -101,7 +105,7 @@ class SparkSession:
return DataFrame(self, sel_expression)
def sql(self, sqlQuery: str) -> DataFrame:
- expression = sqlglot.parse_one(sqlQuery, read="spark")
+ expression = sqlglot.parse_one(sqlQuery, read=self.dialect)
if isinstance(expression, exp.Select):
df = DataFrame(self, expression)
df = df._convert_leaf_to_cte()
@@ -149,3 +153,38 @@ class SparkSession:
def _add_alias_to_mapping(self, name: str, sequence_id: str):
self.name_to_sequence_id_mapping[name].append(sequence_id)
+
+ class Builder:
+ SQLFRAME_DIALECT_KEY = "sqlframe.dialect"
+
+ def __init__(self):
+ self.dialect = "spark"
+
+ def __getattr__(self, item) -> SparkSession.Builder:
+ return self
+
+ def __call__(self, *args, **kwargs):
+ return self
+
+ def config(
+ self,
+ key: t.Optional[str] = None,
+ value: t.Optional[t.Any] = None,
+ *,
+ map: t.Optional[t.Dict[str, t.Any]] = None,
+ **kwargs: t.Any,
+ ) -> SparkSession.Builder:
+ if key == self.SQLFRAME_DIALECT_KEY:
+ self.dialect = value
+ elif map and self.SQLFRAME_DIALECT_KEY in map:
+ self.dialect = map[self.SQLFRAME_DIALECT_KEY]
+ return self
+
+ def getOrCreate(self) -> SparkSession:
+ spark = SparkSession()
+ spark.dialect = Dialect.get_or_raise(self.dialect)()
+ return spark
+
+ @classproperty
+ def builder(cls) -> Builder:
+ return cls.Builder()
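session.py is the core of the change: __new__ makes SparkSession a singleton, and the new Builder accepts the dialect through the 'sqlframe.dialect' config key, the exact pattern recommended by the warning added in dataframe.py. A usage sketch:

    # Sketch: configuring the session dialect through the new builder.
    from sqlglot.dataframe.sql.session import SparkSession

    spark = SparkSession.builder.config("sqlframe.dialect", "bigquery").getOrCreate()
    print(type(spark.dialect).__name__)  # BigQuery
    assert SparkSession() is spark       # __new__ returns the same singleton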
diff --git a/sqlglot/dataframe/sql/window.py b/sqlglot/dataframe/sql/window.py
index c54c07e..c1d913f 100644
--- a/sqlglot/dataframe/sql/window.py
+++ b/sqlglot/dataframe/sql/window.py
@@ -48,7 +48,9 @@ class WindowSpec:
return WindowSpec(self.expression.copy())
def sql(self, **kwargs) -> str:
- return self.expression.sql(dialect="spark", **kwargs)
+ from sqlglot.dataframe.sql.session import SparkSession
+
+ return self.expression.sql(dialect=SparkSession().dialect, **kwargs)
def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
from sqlglot.dataframe.sql.column import Column
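window.py gets the same treatment as column.py: WindowSpec.sql resolves the dialect from the session at call time. A short sketch, assuming the PySpark-style Window API defined in this module (rendered output is illustrative):

    # Sketch: window specs also render with the session dialect.
    from sqlglot.dataframe.sql import functions as F
    from sqlglot.dataframe.sql.window import Window

    spec = Window.partitionBy(F.col("dept")).orderBy(F.col("salary"))
    print(spec.sql())  # e.g. OVER (PARTITION BY dept ORDER BY salary)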