blob: d64a818634d5724c47a00c150f20413925185a4f (plain)
import unittest
from sqlglot.expressions import Func
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer
class TestGenerator(unittest.TestCase):
    """SQL round-trip tests for user-defined ``Func`` subclasses.

    Each test declares a ``Func`` subclass and registers it on a ``Parser``
    subclass through ``default_parser_mappings()``. It then checks that parsing
    a query which calls it and regenerating SQL reproduces the input text.
    """

    def _round_trip(self, udf_class, sql):
        """Parse *sql* with a custom parser and return the regenerated SQL.

        The parser's ``FUNCTIONS`` table is replaced by
        ``udf_class.default_parser_mappings()``. Only the first statement
        returned by ``parse`` is turned back into SQL.
        """
        # Tokenize first, then build and run the parser (same order as the
        # original inline code).
        token_stream = Tokenizer().tokenize(sql)

        class NewParser(Parser):
            FUNCTIONS = udf_class.default_parser_mappings()

        parsed = NewParser().parse(token_stream)
        return parsed[0].sql()

    def test_fallback_function_sql(self):
        """A fixed-arity UDF called with only its required argument."""

        # The class name is kept as-is: the SQL below calls SPECIAL_UDF,
        # which suggests the function name is derived from it.
        class SpecialUDF(Func):
            arg_types = {"a": True, "b": False}

        regenerated = self._round_trip(SpecialUDF, "SELECT SPECIAL_UDF(a) FROM x")
        self.assertEqual(regenerated, "SELECT SPECIAL_UDF(a) FROM x")

    def test_fallback_function_var_args_sql(self):
        """A variable-length-argument UDF, including an expression argument."""

        class SpecialUDF(Func):
            arg_types = {"a": True, "expressions": False}
            is_var_len_args = True

        regenerated = self._round_trip(
            SpecialUDF, "SELECT SPECIAL_UDF(a, b, c, d + 1) FROM x"
        )
        self.assertEqual(regenerated, "SELECT SPECIAL_UDF(a, b, c, d + 1) FROM x")
|