summaryrefslogtreecommitdiffstats
path: root/sqlglot/dataframe/sql/session.py
blob: 1ea86d1956d9bfc9ccb1767ebc2b111390f53ebc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from __future__ import annotations

import typing as t
import uuid
from collections import defaultdict

import sqlglot
from sqlglot import expressions as exp
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.dataframe import DataFrame
from sqlglot.dataframe.sql.readwriter import DataFrameReader
from sqlglot.dataframe.sql.types import StructType
from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput


class SparkSession:
    """Entry point that turns tables, literal rows and SQL text into ``DataFrame`` objects.

    Every attribute that normal lookup cannot find resolves to the session itself
    (see ``__getattr__``), and calling the session returns the session as well, so
    chained calls such as ``session.builder.appName("x").getOrCreate()`` evaluate to
    the same session object.
    """

    # These registries live on the class, so they are shared by every session instance.
    known_ids: t.ClassVar[t.Set[str]] = set()
    known_branch_ids: t.ClassVar[t.Set[str]] = set()
    known_sequence_ids: t.ClassVar[t.Set[str]] = set()
    name_to_sequence_id_mapping: t.ClassVar[t.Dict[str, t.List[str]]] = defaultdict(list)

    def __init__(self):
        # Counter consumed by `_auto_incrementing_name`; the first generated name is "a1".
        self.incrementing_id = 1

    def __getattr__(self, name: str) -> SparkSession:
        """Return the session itself for any attribute that normal lookup did not find.

        Dunder names are excluded: the interpreter and the standard library probe for
        optional hooks (``__getnewargs_ex__``, ``__setstate__``, ``__deepcopy__``, ...)
        via attribute lookup, and answering those probes with ``self`` makes them look
        like real hooks, which breaks ``copy``/``pickle`` on a session.

        Raises:
            AttributeError: if ``name`` starts and ends with a double underscore.
        """
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
        return self

    def __call__(self, *args, **kwargs) -> SparkSession:
        """Ignore all arguments and return the session itself."""
        return self

    @property
    def read(self) -> DataFrameReader:
        """A new ``DataFrameReader`` bound to this session (created on every access)."""
        return DataFrameReader(self)

    def table(self, tableName: str) -> DataFrame:
        """Shortcut for ``self.read.table(tableName)``."""
        return self.read.table(tableName)

    def createDataFrame(
        self,
        data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]],
        schema: t.Optional[SchemaInput] = None,
        samplingRatio: t.Optional[float] = None,
        verifySchema: bool = False,
    ) -> DataFrame:
        """Build a ``DataFrame`` that selects from a ``VALUES`` subquery holding ``data``.

        Column names come from ``schema`` when it is given. Otherwise they are the
        whitespace-stripped keys of the first row when that row is a dict, or the
        positional names ``_1``, ``_2``, ... sized by the first row.

        Dict rows contribute ``row.values()`` in their own iteration order, so every
        dict is expected to list its keys in the same order as the first row.

        Args:
            data: Non-empty sequence of rows (dicts, lists or tuples of literals).
            schema: ``StructType``, string, or non-empty list of column-name strings.
            samplingRatio: Not supported; must be ``None``.
            verifySchema: Not supported; must be falsy.

        Returns:
            A ``DataFrame`` wrapping the generated ``SELECT`` expression.

        Raises:
            NotImplementedError: if ``samplingRatio``/``verifySchema`` are used, or if
                ``schema`` is of an unsupported kind (including an empty list).
            ValueError: if ``data`` is empty.
        """
        if samplingRatio is not None or verifySchema:
            raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")

        if schema is not None:
            if isinstance(schema, list):
                # Only the first entry is inspected; an empty list has no column names
                # to offer and is rejected here instead of failing on `schema[0]`.
                schema_supported = bool(schema) and isinstance(schema[0], str)
            else:
                schema_supported = isinstance(schema, (StructType, str))
            if not schema_supported:
                raise NotImplementedError("Only schema of either list or string of list supported")

        if not data:
            raise ValueError("Must provide data to create into a DataFrame")

        # Imported at call time, after validation, so invalid input fails first.
        from sqlglot.dataframe.sql.dataframe import DataFrame

        first_row = data[0]
        column_mapping: t.Dict[str, t.Optional[str]]
        if schema is not None:
            column_mapping = get_column_mapping_from_schema_input(schema)
        elif isinstance(first_row, dict):
            column_mapping = {col_name.strip(): None for col_name in first_row}
        else:
            column_mapping = {f"_{i}": None for i in range(1, len(first_row) + 1)}

        # One literal tuple per input row.
        data_expressions = [
            exp.Tuple(
                expressions=[
                    F.lit(value).expression
                    for value in (row.values() if isinstance(row, dict) else row)
                ]
            )
            for row in data
        ]

        # Columns with a declared type are cast and re-aliased to keep their name.
        sel_columns = [
            F.col(name).expression
            if data_type is None
            else F.col(name).cast(data_type).alias(name).expression
            for name, data_type in column_mapping.items()
        ]

        values_subquery = exp.Subquery(
            this=exp.Values(expressions=data_expressions),
            alias=exp.TableAlias(
                this=exp.to_identifier(self._auto_incrementing_name),
                columns=[exp.to_identifier(col_name) for col_name in column_mapping],
            ),
        )

        # "from" is a Python keyword, so the arguments are passed through a dict.
        select_kwargs = {
            "expressions": sel_columns,
            "from": exp.From(expressions=[values_subquery]),
        }
        return DataFrame(self, exp.Select(**select_kwargs))

    def sql(self, sqlQuery: str) -> DataFrame:
        """Parse ``sqlQuery`` using the "spark" dialect and wrap it in a ``DataFrame``.

        A ``SELECT`` is wrapped directly. For ``CREATE``/``INSERT`` the statement's
        inner expression becomes the DataFrame's expression and the remaining statement
        is passed along as ``output_expression_container``.

        Raises:
            ValueError: if the parsed statement is not a Select, Create or Insert.
        """
        expression = sqlglot.parse_one(sqlQuery, read="spark")
        if isinstance(expression, exp.Select):
            df = DataFrame(self, expression)
        elif isinstance(expression, (exp.Create, exp.Insert)):
            select_expression = expression.expression.copy()
            if isinstance(expression, exp.Insert):
                # Move the "with" argument from the INSERT onto the inner expression.
                select_expression.set("with", expression.args.get("with"))
                expression.set("with", None)
            # Detach the inner expression from the container statement.
            del expression.args["expression"]
            df = DataFrame(self, select_expression, output_expression_container=expression)
        else:
            raise ValueError("Unknown expression type provided in the SQL. Please create an issue with the SQL.")
        return df._convert_leaf_to_cte()

    @property
    def _auto_incrementing_name(self) -> str:
        """Next name in the sequence ``a1``, ``a2``, ...; every access advances the counter."""
        name = f"a{self.incrementing_id}"
        self.incrementing_id += 1
        return name

    @property
    def _random_name(self) -> str:
        """``"a"`` plus the first 8 characters of a fresh UUID4 string (not recorded anywhere)."""
        return f"a{str(uuid.uuid4())[:8]}"

    @property
    def _random_branch_id(self) -> str:
        """A fresh ``_random_id`` that is additionally recorded in ``known_branch_ids``."""
        branch_id = self._random_id
        self.known_branch_ids.add(branch_id)
        return branch_id

    @property
    def _random_sequence_id(self) -> str:
        """A fresh ``_random_id`` that is additionally recorded in ``known_sequence_ids``."""
        sequence_id = self._random_id
        self.known_sequence_ids.add(sequence_id)
        return sequence_id

    @property
    def _random_id(self) -> str:
        """``"a"`` plus the first 8 characters of a fresh UUID4 string, recorded in ``known_ids``."""
        new_id = f"a{str(uuid.uuid4())[:8]}"
        self.known_ids.add(new_id)
        return new_id

    @property
    def _join_hint_names(self) -> t.Set[str]:
        """Set of join hint names; a new set object is returned on every access."""
        return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}

    def _add_alias_to_mapping(self, name: str, sequence_id: str):
        """Append ``sequence_id`` to the class-level list kept for ``name``."""
        self.name_to_sequence_id_mapping[name].append(sequence_id)