path: root/sqlglot/dialects/spark.py
blob: e828b9bb9bc655b289fc8e417dc657ffdade80b0
from __future__ import annotations

import typing as t

from sqlglot import exp
from sqlglot.dialects.dialect import rename_func, unit_to_var
from sqlglot.dialects.hive import _build_with_ignore_nulls
from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider, _build_as_cast
from sqlglot.helper import ensure_list, seq_get
from sqlglot.transforms import (
    ctas_with_tmp_tables_to_create_tmp_view,
    remove_unique_constraints,
    preprocess,
    move_partitioned_by_to_schema_columns,
)


def _build_datediff(args: t.List) -> exp.Expression:
    """
    Although Spark docs don't mention the "unit" argument, Spark3 added support for
    it at some point. Databricks also supports this variant (see below).

    For example, in spark-sql (v3.3.1):
    - SELECT DATEDIFF('2020-01-01', '2020-01-05') results in -4
    - SELECT DATEDIFF(day, '2020-01-01', '2020-01-05') results in 4

    See also:
    - https://docs.databricks.com/sql/language-manual/functions/datediff3.html
    - https://docs.databricks.com/sql/language-manual/functions/datediff.html
    """
    unit = None
    this = seq_get(args, 0)
    expression = seq_get(args, 1)

    if len(args) == 3:
        unit = this
        this = args[2]

    return exp.DateDiff(
        this=exp.TsOrDsToDate(this=this), expression=exp.TsOrDsToDate(this=expression), unit=unit
    )
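
# Illustrative only (not part of the upstream module): assuming sqlglot's
# parse_one, both DATEDIFF arities above build the same exp.DateDiff node:
#
#   parse_one("SELECT DATEDIFF('2020-01-01', '2020-01-05')", read="spark")
#       -> unit=None, this=end='2020-01-01', expression=start='2020-01-05'
#   parse_one("SELECT DATEDIFF(DAY, '2020-01-01', '2020-01-05')", read="spark")
#       -> unit=DAY, this=end='2020-01-05', expression=start='2020-01-01'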


def _normalize_partition(e: exp.Expression) -> exp.Expression:
    """Normalize the expressions in PARTITION BY (<expression>, <expression>, ...)"""
    if isinstance(e, str):
        return exp.to_identifier(e)
    if isinstance(e, exp.Literal):
        return exp.to_identifier(e.name)
    return e
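
# Sketch (hypothetical input, assuming this dialect's round trip): a quoted
# partition key such as PARTITIONED BY ('ds') is normalized here so that it
# renders as the bare identifier form PARTITIONED BY (ds).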


class Spark(Spark2):
    class Tokenizer(Spark2.Tokenizer):
        RAW_STRINGS = [
            (prefix + q, q)
            for q in t.cast(t.List[str], Spark2.Tokenizer.QUOTES)
            for prefix in ("r", "R")
        ]
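
        # For illustration: this registers delimiter pairs like ("r'", "'") and
        # ('R"', '"'), so Spark raw strings such as r'\d+' tokenize with their
        # backslashes preserved instead of being treated as escape sequences.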

    class Parser(Spark2.Parser):
        FUNCTIONS = {
            **Spark2.Parser.FUNCTIONS,
            "ANY_VALUE": _build_with_ignore_nulls(exp.AnyValue),
            "DATEDIFF": _build_datediff,
            "TIMESTAMP_LTZ": _build_as_cast("TIMESTAMP_LTZ"),
            "TIMESTAMP_NTZ": _build_as_cast("TIMESTAMP_NTZ"),
            "TRY_ELEMENT_AT": lambda args: exp.Bracket(
                this=seq_get(args, 0), expressions=ensure_list(seq_get(args, 1)), safe=True
            ),
        }
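
        # Note (illustrative): TIMESTAMP_LTZ(x) and TIMESTAMP_NTZ(x) parse as
        # casts to the matching timestamp types via _build_as_cast, rather than
        # as ordinary function calls.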

        def _parse_generated_as_identity(
            self,
        ) -> (
            exp.GeneratedAsIdentityColumnConstraint
            | exp.ComputedColumnConstraint
            | exp.GeneratedAsRowColumnConstraint
        ):
            this = super()._parse_generated_as_identity()
            if this.expression:
                return self.expression(exp.ComputedColumnConstraint, this=this.expression)
            return this
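
        # Rough sketch (assuming the usual parse/generate round trip): a column
        # like `c INT GENERATED ALWAYS AS (a + b)` parses into an
        # exp.ComputedColumnConstraint wrapping `a + b`, which the Generator
        # below renders via computedcolumnconstraint_sql.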

    class Generator(Spark2.Generator):
        SUPPORTS_TO_NUMBER = True

        TYPE_MAPPING = {
            **Spark2.Generator.TYPE_MAPPING,
            exp.DataType.Type.MONEY: "DECIMAL(15, 4)",
            exp.DataType.Type.SMALLMONEY: "DECIMAL(6, 4)",
            exp.DataType.Type.UNIQUEIDENTIFIER: "STRING",
            exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP_LTZ",
            exp.DataType.Type.TIMESTAMPNTZ: "TIMESTAMP_NTZ",
        }
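
        # E.g. (illustrative): a MONEY column transpiled from a dialect that
        # has that type would be rendered as DECIMAL(15, 4), since Spark has no
        # dedicated money type.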

        TRANSFORMS = {
            **Spark2.Generator.TRANSFORMS,
            exp.ArrayConstructCompact: lambda self, e: self.func(
                "ARRAY_COMPACT", self.func("ARRAY", *e.expressions)
            ),
            exp.Create: preprocess(
                [
                    remove_unique_constraints,
                    lambda e: ctas_with_tmp_tables_to_create_tmp_view(
                        e, temporary_storage_provider
                    ),
                    move_partitioned_by_to_schema_columns,
                ]
            ),
            exp.PartitionedByProperty: lambda self, e: (
                f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(p) for p in e.this.expressions], skip_first=True))}"
            ),
            exp.StartsWith: rename_func("STARTSWITH"),
            exp.TimestampAdd: lambda self, e: self.func(
                "DATEADD", unit_to_var(e), e.expression, e.this
            ),
            exp.TryCast: lambda self, e: (
                self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e)
            ),
        }
        # These Spark2 transforms are removed so that generation falls back to
        # the anyvalue_sql / datediff_sql methods defined below and the default
        # GROUP BY handling.
        TRANSFORMS.pop(exp.AnyValue)
        TRANSFORMS.pop(exp.DateDiff)
        TRANSFORMS.pop(exp.Group)

        def bracket_sql(self, expression: exp.Bracket) -> str:
            if expression.args.get("safe"):
                key = seq_get(self.bracket_offset_expressions(expression), 0)
                return self.func("TRY_ELEMENT_AT", expression.this, key)

            return super().bracket_sql(expression)
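
        # Sketch of the round trip (hypothetical query): TRY_ELEMENT_AT(arr, 1)
        # parses into a Bracket with safe=True above, and this method renders
        # it back as TRY_ELEMENT_AT(arr, 1); an unsafe arr[1] keeps the default
        # bracket syntax.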

        def computedcolumnconstraint_sql(self, expression: exp.ComputedColumnConstraint) -> str:
            return f"GENERATED ALWAYS AS ({self.sql(expression, 'this')})"

        def anyvalue_sql(self, expression: exp.AnyValue) -> str:
            return self.function_fallback_sql(expression)

        def datediff_sql(self, expression: exp.DateDiff) -> str:
            end = self.sql(expression, "this")
            start = self.sql(expression, "expression")

            if expression.unit:
                return self.func("DATEDIFF", unit_to_var(expression), start, end)

            return self.func("DATEDIFF", end, start)
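
        # Illustration: DateDiff with a unit renders as the three-argument
        # DATEADD-style DATEDIFF(DAY, start, end); without one it falls back to
        # the two-argument DATEDIFF(end, start) form described in
        # _build_datediff.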