summaryrefslogtreecommitdiffstats
path: root/tests/test_generator.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_generator.py')
-rw-r--r--tests/test_generator.py30
1 files changed, 30 insertions, 0 deletions
diff --git a/tests/test_generator.py b/tests/test_generator.py
new file mode 100644
index 0000000..d64a818
--- /dev/null
+++ b/tests/test_generator.py
@@ -0,0 +1,30 @@
+import unittest
+
+from sqlglot.expressions import Func
+from sqlglot.parser import Parser
+from sqlglot.tokens import Tokenizer
+
+
+class TestGenerator(unittest.TestCase):
+ def test_fallback_function_sql(self):
+ class SpecialUDF(Func):
+ arg_types = {"a": True, "b": False}
+
+ class NewParser(Parser):
+ FUNCTIONS = SpecialUDF.default_parser_mappings()
+
+ tokens = Tokenizer().tokenize("SELECT SPECIAL_UDF(a) FROM x")
+ expression = NewParser().parse(tokens)[0]
+ self.assertEqual(expression.sql(), "SELECT SPECIAL_UDF(a) FROM x")
+
+ def test_fallback_function_var_args_sql(self):
+ class SpecialUDF(Func):
+ arg_types = {"a": True, "expressions": False}
+ is_var_len_args = True
+
+ class NewParser(Parser):
+ FUNCTIONS = SpecialUDF.default_parser_mappings()
+
+ tokens = Tokenizer().tokenize("SELECT SPECIAL_UDF(a, b, c, d + 1) FROM x")
+ expression = NewParser().parse(tokens)[0]
+ self.assertEqual(expression.sql(), "SELECT SPECIAL_UDF(a, b, c, d + 1) FROM x")