summaryrefslogtreecommitdiffstats
path: root/sqlglot/dataframe/sql/session.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dataframe/sql/session.py')
-rw-r--r--  sqlglot/dataframe/sql/session.py  65
1 file changed, 52 insertions(+), 13 deletions(-)
diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py
index b883359..531ee17 100644
--- a/sqlglot/dataframe/sql/session.py
+++ b/sqlglot/dataframe/sql/session.py
@@ -5,31 +5,35 @@ import uuid
from collections import defaultdict
import sqlglot
-from sqlglot import expressions as exp
+from sqlglot import Dialect, expressions as exp
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.dataframe import DataFrame
from sqlglot.dataframe.sql.readwriter import DataFrameReader
from sqlglot.dataframe.sql.types import StructType
from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input
+from sqlglot.helper import classproperty
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput
class SparkSession:
- known_ids: t.ClassVar[t.Set[str]] = set()
- known_branch_ids: t.ClassVar[t.Set[str]] = set()
- known_sequence_ids: t.ClassVar[t.Set[str]] = set()
- name_to_sequence_id_mapping: t.ClassVar[t.Dict[str, t.List[str]]] = defaultdict(list)
+ DEFAULT_DIALECT = "spark"
+ _instance = None
def __init__(self):
- self.incrementing_id = 1
-
- def __getattr__(self, name: str) -> SparkSession:
- return self
-
- def __call__(self, *args, **kwargs) -> SparkSession:
- return self
+ if not hasattr(self, "known_ids"):
+ self.known_ids = set()
+ self.known_branch_ids = set()
+ self.known_sequence_ids = set()
+ self.name_to_sequence_id_mapping = defaultdict(list)
+ self.incrementing_id = 1
+ self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)()
+
+ def __new__(cls, *args, **kwargs) -> SparkSession:
+ if cls._instance is None:
+ cls._instance = super().__new__(cls)
+ return cls._instance
@property
def read(self) -> DataFrameReader:
@@ -101,7 +105,7 @@ class SparkSession:
return DataFrame(self, sel_expression)
def sql(self, sqlQuery: str) -> DataFrame:
- expression = sqlglot.parse_one(sqlQuery, read="spark")
+ expression = sqlglot.parse_one(sqlQuery, read=self.dialect)
if isinstance(expression, exp.Select):
df = DataFrame(self, expression)
df = df._convert_leaf_to_cte()
@@ -149,3 +153,38 @@ class SparkSession:
def _add_alias_to_mapping(self, name: str, sequence_id: str):
self.name_to_sequence_id_mapping[name].append(sequence_id)
+
+ class Builder:
+ SQLFRAME_DIALECT_KEY = "sqlframe.dialect"
+
+ def __init__(self):
+ self.dialect = "spark"
+
+ def __getattr__(self, item) -> SparkSession.Builder:
+ return self
+
+ def __call__(self, *args, **kwargs):
+ return self
+
+ def config(
+ self,
+ key: t.Optional[str] = None,
+ value: t.Optional[t.Any] = None,
+ *,
+ map: t.Optional[t.Dict[str, t.Any]] = None,
+ **kwargs: t.Any,
+ ) -> SparkSession.Builder:
+ if key == self.SQLFRAME_DIALECT_KEY:
+ self.dialect = value
+ elif map and self.SQLFRAME_DIALECT_KEY in map:
+ self.dialect = map[self.SQLFRAME_DIALECT_KEY]
+ return self
+
+ def getOrCreate(self) -> SparkSession:
+ spark = SparkSession()
+ spark.dialect = Dialect.get_or_raise(self.dialect)()
+ return spark
+
+ @classproperty
+ def builder(cls) -> Builder:
+ return cls.Builder()