summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/spark.py
blob: 939f2fd52d66aa2a6b202cf52c3191ab30c5ece1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from __future__ import annotations

import typing as t

from sqlglot import exp
from sqlglot.dialects.spark2 import Spark2
from sqlglot.helper import seq_get


def _parse_datediff(args: t.Sequence) -> exp.Expression:
    """
    Build a DateDiff node from the arguments of a Spark DATEDIFF call.

    Two call shapes are accepted:
    - DATEDIFF(end, start): the classic two-argument form.
    - DATEDIFF(unit, start, end): a unit-first form. Spark3 accepts it in practice
      even though the Spark docs don't list it, and Databricks documents it too.

    Both shapes yield end - start. For example, in spark-sql (v3.3.1):
    - SELECT DATEDIFF('2020-01-01', '2020-01-05') results in -4
    - SELECT DATEDIFF(day, '2020-01-01', '2020-01-05') results in 4

    See also:
    - https://docs.databricks.com/sql/language-manual/functions/datediff3.html
    - https://docs.databricks.com/sql/language-manual/functions/datediff.html
    """
    if len(args) == 3:
        # Unit-first form: the leading argument is the unit, and the dates follow
        # in chronological (start, end) order.
        unit, start, end = args
    else:
        # Two-argument form: end date comes first. seq_get tolerates short
        # argument lists by yielding None for missing positions.
        unit = None
        end = seq_get(args, 0)
        start = seq_get(args, 1)

    # Both operands are wrapped in TsOrDsToDate before being stored on the node.
    return exp.DateDiff(
        this=exp.TsOrDsToDate(this=end),
        expression=exp.TsOrDsToDate(this=start),
        unit=unit,
    )


class Spark(Spark2):
    """Spark 3 dialect: the Spark2 dialect with unit-aware DATEDIFF handling."""

    class Parser(Spark2.Parser):
        # Inherit every Spark2 function parser, but send DATEDIFF through
        # _parse_datediff so both the 2-arg and 3-arg shapes are understood.
        FUNCTIONS = {
            **Spark2.Parser.FUNCTIONS,  # type: ignore
            "DATEDIFF": _parse_datediff,
        }

    class Generator(Spark2.Generator):
        # Copy the inherited transforms and drop the DateDiff entry so this
        # class's datediff_sql is used for DateDiff nodes instead.
        # NOTE(review): relies on the base Generator preferring a TRANSFORMS
        # entry over a ``*_sql`` method — confirm in sqlglot's Generator.
        TRANSFORMS = {**Spark2.Generator.TRANSFORMS}
        del TRANSFORMS[exp.DateDiff]

        def datediff_sql(self, expression: exp.DateDiff) -> str:
            """
            Render a DateDiff node as a Spark DATEDIFF call.

            If the node carries a unit, the unit-first form
            DATEDIFF(unit, start, end) is emitted. Otherwise the classic
            DATEDIFF(end, start) is emitted. Both forms compute end - start.
            """
            unit = self.sql(expression, "unit")
            end = self.sql(expression, "this")
            start = self.sql(expression, "expression")

            call_args = (unit, start, end) if unit else (end, start)
            return self.func("DATEDIFF", *call_args)