1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
|
from __future__ import annotations
import typing as t
from sqlglot import exp
from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.spark2 import Spark2
from sqlglot.helper import seq_get
def _parse_datediff(args: t.List) -> exp.Expression:
    """
    Build a DateDiff node from the arguments of a Spark DATEDIFF call.

    The Spark docs don't describe a "unit" argument, yet Spark3 began accepting one at
    some point, and Databricks supports that variant too (see the links below).

    For example, in spark-sql (v3.3.1):
    - SELECT DATEDIFF('2020-01-01', '2020-01-05') results in -4
    - SELECT DATEDIFF(day, '2020-01-01', '2020-01-05') results in 4

    See also:
    - https://docs.databricks.com/sql/language-manual/functions/datediff3.html
    - https://docs.databricks.com/sql/language-manual/functions/datediff.html

    Args:
        args: The call's arguments, laid out as (end, start) or as (unit, start, end).

    Returns:
        An exp.DateDiff with `this` set to the end operand and `expression` set to the
        start operand, each wrapped in exp.TsOrDsToDate, and `unit` set to the leading
        argument of the three-argument form (None otherwise).
    """
    if len(args) == 3:
        # Three-argument form: DATEDIFF(unit, start, end).
        unit, start, end = args
    else:
        # Two-argument form: DATEDIFF(end, start). seq_get is used instead of direct
        # indexing -- presumably so that a short argument list does not raise here.
        unit = None
        end = seq_get(args, 0)
        start = seq_get(args, 1)

    return exp.DateDiff(
        this=exp.TsOrDsToDate(this=end),
        expression=exp.TsOrDsToDate(this=start),
        unit=unit,
    )
class Spark(Spark2):
    """Spark dialect: the Spark2 dialect with the overrides defined below."""

    class Tokenizer(Spark2.Tokenizer):
        # For every quote in Spark2's QUOTES, pair the "r"/"R"-prefixed quote with the
        # bare quote, e.g. ("r'", "'") and ("R'", "'").
        RAW_STRINGS = [
            (marker + quote, quote)
            for quote in t.cast(t.List[str], Spark2.Tokenizer.QUOTES)
            for marker in ("r", "R")
        ]

    class Parser(Spark2.Parser):
        FUNCTIONS = {
            **Spark2.Parser.FUNCTIONS,
            # ANY_VALUE(expr[, flag]): a second argument, when present, is stored as
            # `ignore_nulls` on the AnyValue node.
            "ANY_VALUE": lambda args: exp.AnyValue(
                this=seq_get(args, 0),
                ignore_nulls=seq_get(args, 1),
            ),
            # Handles both the two- and three-argument DATEDIFF forms.
            "DATEDIFF": _parse_datediff,
        }

        # Start from a copy of Spark2's table and drop its ANY_VALUE entry; ANY_VALUE
        # is registered in FUNCTIONS above instead.
        FUNCTION_PARSERS = Spark2.Parser.FUNCTION_PARSERS.copy()
        del FUNCTION_PARSERS["ANY_VALUE"]

        def _parse_generated_as_identity(
            self,
        ) -> exp.GeneratedAsIdentityColumnConstraint | exp.ComputedColumnConstraint:
            """
            Parse a GENERATED ... AS constraint via the parent parser.

            When the parent's result carries an `expression`, that expression is
            re-wrapped in an exp.ComputedColumnConstraint; otherwise the parent's
            result is returned unchanged.
            """
            constraint = super()._parse_generated_as_identity()

            if not constraint.expression:
                return constraint

            return self.expression(exp.ComputedColumnConstraint, this=constraint.expression)

    class Generator(Spark2.Generator):
        TYPE_MAPPING = {
            **Spark2.Generator.TYPE_MAPPING,
            exp.DataType.Type.MONEY: "DECIMAL(15, 4)",
            exp.DataType.Type.SMALLMONEY: "DECIMAL(6, 4)",
            exp.DataType.Type.UNIQUEIDENTIFIER: "STRING",
        }

        TRANSFORMS = {
            **Spark2.Generator.TRANSFORMS,
            exp.StartsWith: rename_func("STARTSWITH"),
            # DATEADD(unit, amount, timestamp); the unit falls back to "DAY" when the
            # node has none.
            exp.TimestampAdd: lambda self, e: self.func(
                "DATEADD",
                e.args.get("unit") or "DAY",
                e.expression,
                e.this,
            ),
            # Pick the generator method from the node's "safe" arg, then apply it.
            exp.TryCast: lambda self, e: (
                self.trycast_sql if e.args.get("safe") else self.cast_sql
            )(e),
        }

        # Drop the inherited transforms for these node types. AnyValue and DateDiff
        # have dedicated *_sql methods below -- presumably those take over once the
        # transform entry is gone; the reason for dropping Group is not visible here.
        del TRANSFORMS[exp.AnyValue]
        del TRANSFORMS[exp.DateDiff]
        del TRANSFORMS[exp.Group]

        def computedcolumnconstraint_sql(self, expression: exp.ComputedColumnConstraint) -> str:
            """Render a computed-column constraint as GENERATED ALWAYS AS (<this>)."""
            computed_sql = self.sql(expression, "this")
            return f"GENERATED ALWAYS AS ({computed_sql})"

        def anyvalue_sql(self, expression: exp.AnyValue) -> str:
            """Render AnyValue through the generic function fallback."""
            return self.function_fallback_sql(expression)

        def datediff_sql(self, expression: exp.DateDiff) -> str:
            """
            Render DateDiff as DATEDIFF(unit, start, end) when a unit is present,
            and as DATEDIFF(end, start) otherwise.
            """
            unit_sql = self.sql(expression, "unit")
            end_sql = self.sql(expression, "this")
            start_sql = self.sql(expression, "expression")

            # The operand order flips between the two forms, matching the layouts
            # that _parse_datediff accepts.
            operands = (unit_sql, start_sql, end_sql) if unit_sql else (end_sql, start_sql)
            return self.func("DATEDIFF", *operands)
|