path: root/sqlglot/dataframe/sql/session.py
from __future__ import annotations

import typing as t
import uuid
from collections import defaultdict

import sqlglot
from sqlglot import Dialect, expressions as exp
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.dataframe import DataFrame
from sqlglot.dataframe.sql.readwriter import DataFrameReader
from sqlglot.dataframe.sql.types import StructType
from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input
from sqlglot.helper import classproperty

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput


class SparkSession:
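    """A PySpark-compatible ``SparkSession`` that builds SQL with sqlglot instead of running Spark.

    The session is a process-wide singleton that tracks the ids, aliases, and dialect
    the DataFrame API uses when generating SQL expressions.
    """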
    DEFAULT_DIALECT = "spark"
    _instance = None

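    # __new__ always returns the shared instance, so __init__ runs on every
    # SparkSession() call; the hasattr guard ensures state is only initialized once.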
    def __init__(self):
        if not hasattr(self, "known_ids"):
            self.known_ids = set()
            self.known_branch_ids = set()
            self.known_sequence_ids = set()
            self.name_to_sequence_id_mapping = defaultdict(list)
            self.incrementing_id = 1
            self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)

    def __new__(cls, *args, **kwargs) -> SparkSession:
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @property
    def read(self) -> DataFrameReader:
        return DataFrameReader(self)

    def table(self, tableName: str) -> DataFrame:
        return self.read.table(tableName)

    def createDataFrame(
        self,
        data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
        schema: t.Optional[SchemaInput] = None,
        samplingRatio: t.Optional[float] = None,
        verifySchema: bool = False,
    ) -> DataFrame:
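        """Build a DataFrame backed by a ``SELECT ... FROM (VALUES ...)`` expression.

        ``data`` may be a sequence of dicts, lists, or tuples. ``schema`` may be a
        ``StructType``, a schema string, or a list of column names; if omitted, dict
        keys or positional names (``_1``, ``_2``, ...) are used. ``samplingRatio`` and
        ``verifySchema`` exist only for PySpark API compatibility and are rejected.
        """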
        from sqlglot.dataframe.sql.dataframe import DataFrame

        if samplingRatio is not None or verifySchema:
            raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
        if schema is not None and (
            not isinstance(schema, (StructType, str, list))
            or (isinstance(schema, list) and not isinstance(schema[0], str))
        ):
            raise NotImplementedError("Only schema of either list or string of list supported")
        if not data:
            raise ValueError("Must provide data to create into a DataFrame")

        column_mapping: t.Dict[str, t.Optional[str]]
        if schema is not None:
            column_mapping = get_column_mapping_from_schema_input(schema)
        elif isinstance(data[0], dict):
            column_mapping = {col_name.strip(): None for col_name in data[0]}
        else:
            column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}

        data_expressions = [
            exp.Tuple(
                expressions=list(
                    map(
                        lambda x: F.lit(x).expression,
                        row if not isinstance(row, dict) else row.values(),
                    )
                )
            )
            for row in data
        ]

        sel_columns = [
            (
                F.col(name).cast(data_type).alias(name).expression
                if data_type is not None
                else F.col(name).expression
            )
            for name, data_type in column_mapping.items()
        ]

        select_kwargs = {
            "expressions": sel_columns,
            "from": exp.From(
                this=exp.Values(
                    expressions=data_expressions,
                    alias=exp.TableAlias(
                        this=exp.to_identifier(self._auto_incrementing_name),
                        columns=[exp.to_identifier(col_name) for col_name in column_mapping],
                    ),
                ),
            ),
        }

        sel_expression = exp.Select(**select_kwargs)
        return DataFrame(self, sel_expression)

    def sql(self, sqlQuery: str) -> DataFrame:
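        """Parse ``sqlQuery`` using the session dialect and wrap it in a DataFrame.

        SELECT statements become the DataFrame's expression directly. For CREATE and
        INSERT statements, the inner SELECT becomes the DataFrame's expression and the
        outer statement is kept as the output expression container. Any other statement
        type raises a ``ValueError``.
        """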
        expression = sqlglot.parse_one(sqlQuery, read=self.dialect)
        if isinstance(expression, exp.Select):
            df = DataFrame(self, expression)
            df = df._convert_leaf_to_cte()
        elif isinstance(expression, (exp.Create, exp.Insert)):
            select_expression = expression.expression.copy()
            if isinstance(expression, exp.Insert):
                select_expression.set("with", expression.args.get("with"))
                expression.set("with", None)
            del expression.args["expression"]
            df = DataFrame(self, select_expression, output_expression_container=expression)  # type: ignore
            df = df._convert_leaf_to_cte()
        else:
            raise ValueError(
                "Unknown expression type provided in the SQL. Please create an issue with the SQL."
            )
        return df

    @property
    def _auto_incrementing_name(self) -> str:
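        """Return the next auto-generated alias name: ``a1``, ``a2``, ..."""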
        name = f"a{self.incrementing_id}"
        self.incrementing_id += 1
        return name

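    # Each _random_* property returns a fresh "r"-prefixed uuid-based id and records it
    # in the corresponding known_* set so the session keeps track of every id it has
    # handed out.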
    @property
    def _random_branch_id(self) -> str:
        id = self._random_id
        self.known_branch_ids.add(id)
        return id

    @property
    def _random_sequence_id(self) -> str:
        id = self._random_id
        self.known_sequence_ids.add(id)
        return id

    @property
    def _random_id(self) -> str:
        id = "r" + uuid.uuid4().hex
        self.known_ids.add(id)
        return id

    @property
    def _join_hint_names(self) -> t.Set[str]:
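        """The set of Spark join hint names recognized by the DataFrame API."""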
        return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}

    def _add_alias_to_mapping(self, name: str, sequence_id: str):
        self.name_to_sequence_id_mapping[name].append(sequence_id)

    class Builder:
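        """Mimics ``pyspark.sql.SparkSession.Builder``.

        ``__getattr__`` and ``__call__`` make arbitrary chained builder calls no-ops,
        so code written for PySpark (``appName``, ``master``, etc.) runs unchanged;
        only ``config`` with the ``sqlframe.dialect`` key changes behavior.
        """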
        SQLFRAME_DIALECT_KEY = "sqlframe.dialect"

        def __init__(self):
            self.dialect = "spark"

        def __getattr__(self, item) -> SparkSession.Builder:
            return self

        def __call__(self, *args, **kwargs):
            return self

        def config(
            self,
            key: t.Optional[str] = None,
            value: t.Optional[t.Any] = None,
            *,
            map: t.Optional[t.Dict[str, t.Any]] = None,
            **kwargs: t.Any,
        ) -> SparkSession.Builder:
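            """Record the dialect when set via the ``sqlframe.dialect`` key, either
            directly or through ``map``; all other settings are ignored."""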
            if key == self.SQLFRAME_DIALECT_KEY:
                self.dialect = value
            elif map and self.SQLFRAME_DIALECT_KEY in map:
                self.dialect = map[self.SQLFRAME_DIALECT_KEY]
            return self

        def getOrCreate(self) -> SparkSession:
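            """Return the singleton SparkSession with the configured dialect applied."""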
            spark = SparkSession()
            spark.dialect = Dialect.get_or_raise(self.dialect)
            return spark

    @classproperty
    def builder(cls) -> Builder:
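        """Entry point mirroring ``SparkSession.builder`` in PySpark; returns a fresh Builder."""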
        return cls.Builder()
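

# A minimal usage sketch (not part of the library API surface; the exact SQL
# produced depends on the sqlglot version):
if __name__ == "__main__":
    spark = SparkSession.builder.config("sqlframe.dialect", "spark").getOrCreate()
    df = spark.createDataFrame([(1, 2)], ["a", "b"])
    print(df.sql(pretty=True))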