diff options
Diffstat (limited to 'sqlglot/dataframe/sql/session.py')
-rw-r--r-- | sqlglot/dataframe/sql/session.py | 65 |
1 files changed, 52 insertions, 13 deletions
diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py index b883359..531ee17 100644 --- a/sqlglot/dataframe/sql/session.py +++ b/sqlglot/dataframe/sql/session.py @@ -5,31 +5,35 @@ import uuid from collections import defaultdict import sqlglot -from sqlglot import expressions as exp +from sqlglot import Dialect, expressions as exp from sqlglot.dataframe.sql import functions as F from sqlglot.dataframe.sql.dataframe import DataFrame from sqlglot.dataframe.sql.readwriter import DataFrameReader from sqlglot.dataframe.sql.types import StructType from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input +from sqlglot.helper import classproperty if t.TYPE_CHECKING: from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput class SparkSession: - known_ids: t.ClassVar[t.Set[str]] = set() - known_branch_ids: t.ClassVar[t.Set[str]] = set() - known_sequence_ids: t.ClassVar[t.Set[str]] = set() - name_to_sequence_id_mapping: t.ClassVar[t.Dict[str, t.List[str]]] = defaultdict(list) + DEFAULT_DIALECT = "spark" + _instance = None def __init__(self): - self.incrementing_id = 1 - - def __getattr__(self, name: str) -> SparkSession: - return self - - def __call__(self, *args, **kwargs) -> SparkSession: - return self + if not hasattr(self, "known_ids"): + self.known_ids = set() + self.known_branch_ids = set() + self.known_sequence_ids = set() + self.name_to_sequence_id_mapping = defaultdict(list) + self.incrementing_id = 1 + self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)() + + def __new__(cls, *args, **kwargs) -> SparkSession: + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance @property def read(self) -> DataFrameReader: @@ -101,7 +105,7 @@ class SparkSession: return DataFrame(self, sel_expression) def sql(self, sqlQuery: str) -> DataFrame: - expression = sqlglot.parse_one(sqlQuery, read="spark") + expression = sqlglot.parse_one(sqlQuery, read=self.dialect) if isinstance(expression, exp.Select): df = DataFrame(self, expression) df = df._convert_leaf_to_cte() @@ -149,3 +153,38 @@ class SparkSession: def _add_alias_to_mapping(self, name: str, sequence_id: str): self.name_to_sequence_id_mapping[name].append(sequence_id) + + class Builder: + SQLFRAME_DIALECT_KEY = "sqlframe.dialect" + + def __init__(self): + self.dialect = "spark" + + def __getattr__(self, item) -> SparkSession.Builder: + return self + + def __call__(self, *args, **kwargs): + return self + + def config( + self, + key: t.Optional[str] = None, + value: t.Optional[t.Any] = None, + *, + map: t.Optional[t.Dict[str, t.Any]] = None, + **kwargs: t.Any, + ) -> SparkSession.Builder: + if key == self.SQLFRAME_DIALECT_KEY: + self.dialect = value + elif map and self.SQLFRAME_DIALECT_KEY in map: + self.dialect = map[self.SQLFRAME_DIALECT_KEY] + return self + + def getOrCreate(self) -> SparkSession: + spark = SparkSession() + spark.dialect = Dialect.get_or_raise(self.dialect)() + return spark + + @classproperty + def builder(cls) -> Builder: + return cls.Builder() |