summaryrefslogtreecommitdiffstats
path: root/tests/test_generator.py
blob: e21aa2512fa6374028913d33afd9ef16419f9b8f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import time
import unittest

from sqlglot import exp, parse_one
from sqlglot.expressions import Func
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer


class TestGenerator(unittest.TestCase):
    """Tests for rendering sqlglot expression trees back to SQL text."""

    def test_fallback_function_sql(self):
        """A locally defined ``Func`` subclass round-trips to SQL unchanged."""

        class SpecialUDF(Func):
            arg_types = {"a": True, "b": False}

        # Register the UDF with a parser; no generator is defined for it here,
        # so rendering exercises the generic (fallback) function path.
        class NewParser(Parser):
            FUNCTIONS = SpecialUDF.default_parser_mappings()

        tokens = Tokenizer().tokenize("SELECT SPECIAL_UDF(a) FROM x")
        expression = NewParser().parse(tokens)[0]
        self.assertEqual(expression.sql(), "SELECT SPECIAL_UDF(a) FROM x")

    def test_fallback_function_var_args_sql(self):
        """A var-length-args ``Func`` subclass keeps all its arguments on round-trip."""

        class SpecialUDF(Func):
            arg_types = {"a": True, "expressions": False}
            is_var_len_args = True

        class NewParser(Parser):
            FUNCTIONS = SpecialUDF.default_parser_mappings()

        tokens = Tokenizer().tokenize("SELECT SPECIAL_UDF(a, b, c, d + 1) FROM x")
        expression = NewParser().parse(tokens)[0]
        self.assertEqual(expression.sql(), "SELECT SPECIAL_UDF(a, b, c, d + 1) FROM x")

        # NOTE(review): this DateTrunc check is not about var-length args;
        # consider moving it into its own test method.
        self.assertEqual(
            exp.DateTrunc(this=exp.to_column("event_date"), unit=exp.var("MONTH")).sql(),
            "DATE_TRUNC('MONTH', event_date)",
        )

    def test_identify(self):
        """Identifier quoting under ``identify=True``, ``"always"`` and ``"safe"``."""
        # (input sql, identify setting, expected output). Under "safe", the
        # lowercase ``x`` and the numeric alias ``1`` are quoted while the
        # uppercase ``X`` is left bare, per the expectations below.
        cases = [
            ("x", True, '"x"'),
            ("x", "always", '"x"'),
            ("X", "always", '"X"'),
            ("x", "safe", '"x"'),
            ("X", "safe", "X"),
            ("x as 1", "safe", '"x" AS "1"'),
            ("X as 1", "safe", 'X AS "1"'),
        ]
        # Use assertEqual instead of bare ``assert``: bare asserts are stripped
        # under ``python -O`` (making this test a silent no-op), and subTest
        # reports every failing case instead of stopping at the first.
        for sql, identify, expected in cases:
            with self.subTest(sql=sql, identify=identify):
                self.assertEqual(parse_one(sql).sql(identify=identify), expected)

    def test_generate_nested_binary(self):
        """A 500-term ``||`` chain round-trips and stays under a 0.1 s time budget."""
        sql = "SELECT 'foo'" + (" || 'foo'" * 500)

        # perf_counter is monotonic and high-resolution; time.time() follows the
        # wall clock, which can jump and make this timing bound flaky.
        start = time.perf_counter()
        self.assertEqual(parse_one(sql).sql(), sql)
        self.assertLessEqual(time.perf_counter() - start, 0.1)