diff options
-rw-r--r-- | CHANGELOG.md | 17 | ||||
-rw-r--r-- | sqlglot/__init__.py | 2 | ||||
-rw-r--r-- | sqlglot/dialects/__init__.py | 1 | ||||
-rw-r--r-- | sqlglot/dialects/dialect.py | 19 | ||||
-rw-r--r-- | sqlglot/dialects/duckdb.py | 2 | ||||
-rw-r--r-- | sqlglot/dialects/snowflake.py | 20 | ||||
-rw-r--r-- | sqlglot/dialects/spark.py | 1 | ||||
-rw-r--r-- | sqlglot/dialects/tsql.py | 38 | ||||
-rw-r--r-- | sqlglot/expressions.py | 74 | ||||
-rw-r--r-- | sqlglot/generator.py | 5 | ||||
-rw-r--r-- | sqlglot/parser.py | 26 | ||||
-rw-r--r-- | sqlglot/tokens.py | 14 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 5 | ||||
-rw-r--r-- | tests/dialects/test_duckdb.py | 3 | ||||
-rw-r--r-- | tests/dialects/test_hive.py | 2 | ||||
-rw-r--r-- | tests/dialects/test_oracle.py | 6 | ||||
-rw-r--r-- | tests/dialects/test_presto.py | 4 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 45 | ||||
-rw-r--r-- | tests/dialects/test_spark.py | 12 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 26 | ||||
-rw-r--r-- | tests/test_expressions.py | 32 | ||||
-rw-r--r-- | tests/test_parser.py | 3 |
22 files changed, 312 insertions, 45 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index 0eba6cc..f16fc70 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,23 @@ Changelog ========= +v6.2.0 +------ + +Changes: + +- New: TSQL support + +- Breaking: Removed $ from tokenizer, added @ placeholders + +- Improvement: Nodes can now be removed in transform and replace [8cd81c3](https://github.com/tobymao/sqlglot/commit/8cd81c36561463b9849a8e0c2d70248c5b1feb62) + +- Improvement: Snowflake timestamp support + +- Improvement: Property conversion for CTAS Builder + +- Improvement: Tokenizers are now unique per dialect instance + v6.1.0 ------ diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 3fa40ce..44d349b 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -20,7 +20,7 @@ from sqlglot.generator import Generator from sqlglot.parser import Parser from sqlglot.tokens import Tokenizer, TokenType -__version__ = "6.1.1" +__version__ = "6.2.0" pretty = False diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py index f7d03ad..0f80723 100644 --- a/sqlglot/dialects/__init__.py +++ b/sqlglot/dialects/__init__.py @@ -14,3 +14,4 @@ from sqlglot.dialects.sqlite import SQLite from sqlglot.dialects.starrocks import StarRocks from sqlglot.dialects.tableau import Tableau from sqlglot.dialects.trino import Trino +from sqlglot.dialects.tsql import TSQL diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index f338c81..0120e71 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -27,6 +27,7 @@ class Dialects(str, Enum): STARROCKS = "starrocks" TABLEAU = "tableau" TRINO = "trino" + TSQL = "tsql" class _Dialect(type): @@ -53,7 +54,6 @@ class _Dialect(type): klass.parser_class = getattr(klass, "Parser", Parser) klass.generator_class = getattr(klass, "Generator", Generator) - klass.tokenizer = klass.tokenizer_class() klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0] klass.identifier_start, klass.identifier_end = list(klass.tokenizer_class._IDENTIFIERS.items())[0] @@ -95,7 +95,6 @@ class Dialect(metaclass=_Dialect): tokenizer_class = None parser_class = None generator_class = None - tokenizer = None @classmethod def get_or_raise(cls, dialect): @@ -138,6 +137,12 @@ class Dialect(metaclass=_Dialect): def transpile(self, code, **opts): return self.generate(self.parse(code), **opts) + @property + def tokenizer(self): + if not hasattr(self, "_tokenizer"): + self._tokenizer = self.tokenizer_class() + return self._tokenizer + def parser(self, **opts): return self.parser_class( **{ @@ -170,7 +175,15 @@ class Dialect(metaclass=_Dialect): def rename_func(name): - return lambda self, expression: f"{name}({csv(*[self.sql(e) for e in expression.args.values()])})" + def _rename(self, expression): + args = ( + self.expressions(expression, flat=True) + if isinstance(expression, exp.Func) and expression.is_var_len_args + else csv(*[self.sql(e) for e in expression.args.values()]) + ) + return f"{name}({args})" + + return _rename def approx_count_distinct_sql(self, expression): diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index ff3a8b1..4ca9e84 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -108,7 +108,7 @@ class DuckDB(Dialect): TRANSFORMS = { **Generator.TRANSFORMS, exp.ApproxDistinct: approx_count_distinct_sql, - exp.Array: lambda self, e: f"LIST_VALUE({self.expressions(e, flat=True)})", + exp.Array: rename_func("LIST_VALUE"), exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.ArraySort: _array_sort_sql, exp.ArraySum: rename_func("LIST_SUM"), diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 8d6ee78..b5d4f0a 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -106,6 +106,11 @@ class Snowflake(Dialect): "TO_TIMESTAMP": _snowflake_to_timestamp, } + FUNCTION_PARSERS = { + **Parser.FUNCTION_PARSERS, + "DATE_PART": lambda self: self._parse_extract(), + } + COLUMN_OPERATORS = { **Parser.COLUMN_OPERATORS, TokenType.COLON: lambda self, this, path: self.expression( @@ -118,10 +123,20 @@ class Snowflake(Dialect): class Tokenizer(Tokenizer): QUOTES = ["'", "$$"] ESCAPE = "\\" + + SINGLE_TOKENS = { + **Tokenizer.SINGLE_TOKENS, + "$": TokenType.DOLLAR, # needed to break for quotes + } + KEYWORDS = { **Tokenizer.KEYWORDS, "QUALIFY": TokenType.QUALIFY, "DOUBLE PRECISION": TokenType.DOUBLE, + "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ, + "TIMESTAMP_NTZ": TokenType.TIMESTAMP, + "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ, + "TIMESTAMPNTZ": TokenType.TIMESTAMP, } class Generator(Generator): @@ -132,6 +147,11 @@ class Snowflake(Dialect): exp.UnixToTime: _unix_to_time, } + TYPE_MAPPING = { + **Generator.TYPE_MAPPING, + exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ", + } + def except_op(self, expression): if not expression.args.get("distinct", False): self.unsupported("EXCEPT with All is not supported in Snowflake") diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index a331191..c051178 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -82,6 +82,7 @@ class Spark(Hive): TRANSFORMS = { **{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort}}, + exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}", exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", exp.BitwiseLeftShift: rename_func("SHIFTLEFT"), exp.BitwiseRightShift: rename_func("SHIFTRIGHT"), diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py new file mode 100644 index 0000000..68bb9bd --- /dev/null +++ b/sqlglot/dialects/tsql.py @@ -0,0 +1,38 @@ +from sqlglot import exp +from sqlglot.dialects.dialect import Dialect +from sqlglot.generator import Generator +from sqlglot.tokens import Tokenizer, TokenType + + +class TSQL(Dialect): + null_ordering = "nulls_are_small" + time_format = "'yyyy-mm-dd hh:mm:ss'" + + class Tokenizer(Tokenizer): + IDENTIFIERS = ['"', ("[", "]")] + + KEYWORDS = { + **Tokenizer.KEYWORDS, + "BIT": TokenType.BOOLEAN, + "REAL": TokenType.FLOAT, + "NTEXT": TokenType.TEXT, + "SMALLDATETIME": TokenType.DATETIME, + "DATETIMEOFFSET": TokenType.TIMESTAMPTZ, + "TIME": TokenType.TIMESTAMP, + "VARBINARY": TokenType.BINARY, + "IMAGE": TokenType.IMAGE, + "MONEY": TokenType.MONEY, + "SMALLMONEY": TokenType.SMALLMONEY, + "ROWVERSION": TokenType.ROWVERSION, + "SQL_VARIANT": TokenType.SQL_VARIANT, + "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER, + "XML": TokenType.XML, + } + + class Generator(Generator): + TYPE_MAPPING = { + **Generator.TYPE_MAPPING, + exp.DataType.Type.BOOLEAN: "BIT", + exp.DataType.Type.INT: "INTEGER", + exp.DataType.Type.DECIMAL: "NUMERIC", + } diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index b983bf9..9299132 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1,4 +1,5 @@ import inspect +import numbers import re import sys from collections import deque @@ -6,7 +7,7 @@ from copy import deepcopy from enum import auto from sqlglot.errors import ParseError -from sqlglot.helper import AutoName, camel_to_snake_case, ensure_list +from sqlglot.helper import AutoName, camel_to_snake_case, ensure_list, list_get class _Expression(type): @@ -350,7 +351,8 @@ class Expression(metaclass=_Expression): Args: fun (function): a function which takes a node as an argument and returns a - new transformed node or the same node without modifications. + new transformed node or the same node without modifications. If the function + returns None, then the corresponding node will be removed from the syntax tree. copy (bool): if set to True a new tree instance is constructed, otherwise the tree is modified in place. @@ -360,9 +362,7 @@ class Expression(metaclass=_Expression): node = self.copy() if copy else self new_node = fun(node, *args, **kwargs) - if new_node is None: - raise ValueError("A transformed node cannot be None") - if not isinstance(new_node, Expression): + if new_node is None or not isinstance(new_node, Expression): return new_node if new_node is not node: new_node.parent = node.parent @@ -843,10 +843,6 @@ class Ordered(Expression): arg_types = {"this": True, "desc": True, "nulls_first": True} -class Properties(Expression): - arg_types = {"expressions": True} - - class Property(Expression): arg_types = {"this": True, "value": True} @@ -891,6 +887,42 @@ class AnonymousProperty(Property): pass +class Properties(Expression): + arg_types = {"expressions": True} + + PROPERTY_KEY_MAPPING = { + "AUTO_INCREMENT": AutoIncrementProperty, + "CHARACTER_SET": CharacterSetProperty, + "COLLATE": CollateProperty, + "COMMENT": SchemaCommentProperty, + "ENGINE": EngineProperty, + "FORMAT": FileFormatProperty, + "LOCATION": LocationProperty, + "PARTITIONED_BY": PartitionedByProperty, + "TABLE_FORMAT": TableFormatProperty, + } + + @classmethod + def from_dict(cls, properties_dict): + expressions = [] + for key, value in properties_dict.items(): + property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty) + expressions.append(property_cls(this=Literal.string(key), value=cls._convert_value(value))) + return cls(expressions=expressions) + + @staticmethod + def _convert_value(value): + if isinstance(value, Expression): + return value + if isinstance(value, str): + return Literal.string(value) + if isinstance(value, numbers.Number): + return Literal.number(value) + if isinstance(value, list): + return Tuple(expressions=[_convert_value(v) for v in value]) + raise ValueError(f"Unsupported type '{type(value)}' for value '{value}'") + + class Qualify(Expression): pass @@ -1562,15 +1594,7 @@ class Select(Subqueryable, Expression): ) properties_expression = None if properties: - properties_str = " ".join( - [f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" for k, v in properties.items()] - ) - properties_expression = maybe_parse( - properties_str, - into=Properties, - dialect=dialect, - **opts, - ) + properties_expression = Properties.from_dict(properties) return Create( this=table_expression, @@ -1650,6 +1674,10 @@ class Star(Expression): return "*" +class Parameter(Expression): + pass + + class Placeholder(Expression): arg_types = {} @@ -1688,6 +1716,7 @@ class DataType(Expression): INTERVAL = auto() TIMESTAMP = auto() TIMESTAMPTZ = auto() + TIMESTAMPLTZ = auto() DATE = auto() DATETIME = auto() ARRAY = auto() @@ -1702,6 +1731,13 @@ class DataType(Expression): SERIAL = auto() SMALLSERIAL = auto() BIGSERIAL = auto() + XML = auto() + UNIQUEIDENTIFIER = auto() + MONEY = auto() + SMALLMONEY = auto() + ROWVERSION = auto() + IMAGE = auto() + SQL_VARIANT = auto() @classmethod def build(cls, dtype, **kwargs): @@ -2976,7 +3012,7 @@ def replace_children(expression, fun): else: new_child_nodes.append(cn) - expression.args[k] = new_child_nodes if is_list_arg else new_child_nodes[0] + expression.args[k] = new_child_nodes if is_list_arg else list_get(new_child_nodes, 0) def column_table_names(expression): diff --git a/sqlglot/generator.py b/sqlglot/generator.py index a445178..d264e59 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -748,6 +748,9 @@ class Generator: def structkwarg_sql(self, expression): return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" + def parameter_sql(self, expression): + return f"@{self.sql(expression, 'this')}" + def placeholder_sql(self, *_): return "?" @@ -903,7 +906,7 @@ class Generator: return f"UNIQUE ({columns})" def if_sql(self, expression): - return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false"))) + return self.case_sql(exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))) def in_sql(self, expression): query = expression.args.get("query") diff --git a/sqlglot/parser.py b/sqlglot/parser.py index f46bafe..6ad6391 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -81,6 +81,7 @@ class Parser: TokenType.INTERVAL, TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, + TokenType.TIMESTAMPLTZ, TokenType.DATETIME, TokenType.DATE, TokenType.DECIMAL, @@ -92,6 +93,13 @@ class Parser: TokenType.SERIAL, TokenType.SMALLSERIAL, TokenType.BIGSERIAL, + TokenType.XML, + TokenType.UNIQUEIDENTIFIER, + TokenType.MONEY, + TokenType.SMALLMONEY, + TokenType.ROWVERSION, + TokenType.IMAGE, + TokenType.SQL_VARIANT, *NESTED_TYPE_TOKENS, } @@ -233,6 +241,7 @@ class Parser: TIMESTAMPS = { TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, + TokenType.TIMESTAMPLTZ, } SET_OPERATIONS = { @@ -315,6 +324,7 @@ class Parser: TokenType.TRUE: lambda *_: exp.Boolean(this=True), TokenType.FALSE: lambda *_: exp.Boolean(this=False), TokenType.PLACEHOLDER: lambda *_: exp.Placeholder(), + TokenType.PARAMETER: lambda self, _: exp.Parameter(this=self._parse_var() or self._parse_primary()), TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text), TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text), TokenType.INTRODUCER: lambda self, token: self.expression( @@ -1497,12 +1507,19 @@ class Parser: if type_token in self.TIMESTAMPS: tz = self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ - self._match(TokenType.WITHOUT_TIME_ZONE) if tz: return exp.DataType( this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions, ) + ltz = self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ + if ltz: + return exp.DataType( + this=exp.DataType.Type.TIMESTAMPLTZ, + expressions=expressions, + ) + self._match(TokenType.WITHOUT_TIME_ZONE) + return exp.DataType( this=exp.DataType.Type.TIMESTAMP, expressions=expressions, @@ -1845,8 +1862,11 @@ class Parser: def _parse_extract(self): this = self._parse_var() or self._parse_type() - if not self._match(TokenType.FROM): - self.raise_error("Expected FROM after EXTRACT", self._prev) + if self._match(TokenType.FROM): + return self.expression(exp.Extract, this=this, expression=self._parse_bitwise()) + + if not self._match(TokenType.COMMA): + self.raise_error("Expected FROM or comma after EXTRACT", self._prev) return self.expression(exp.Extract, this=this, expression=self._parse_bitwise()) diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index bd95bc7..7a50fc3 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -41,6 +41,7 @@ class TokenType(AutoName): LR_ARROW = auto() ANNOTATION = auto() DOLLAR = auto() + PARAMETER = auto() SPACE = auto() BREAK = auto() @@ -75,6 +76,7 @@ class TokenType(AutoName): JSON = auto() TIMESTAMP = auto() TIMESTAMPTZ = auto() + TIMESTAMPLTZ = auto() DATETIME = auto() DATE = auto() UUID = auto() @@ -86,6 +88,13 @@ class TokenType(AutoName): SERIAL = auto() SMALLSERIAL = auto() BIGSERIAL = auto() + XML = auto() + UNIQUEIDENTIFIER = auto() + MONEY = auto() + SMALLMONEY = auto() + ROWVERSION = auto() + IMAGE = auto() + SQL_VARIANT = auto() # keywords ADD_FILE = auto() @@ -247,6 +256,7 @@ class TokenType(AutoName): WINDOW = auto() WITH = auto() WITH_TIME_ZONE = auto() + WITH_LOCAL_TIME_ZONE = auto() WITHIN_GROUP = auto() WITHOUT_TIME_ZONE = auto() UNIQUE = auto() @@ -340,7 +350,7 @@ class Tokenizer(metaclass=_Tokenizer): "~": TokenType.TILDA, "?": TokenType.PLACEHOLDER, "#": TokenType.ANNOTATION, - "$": TokenType.DOLLAR, + "@": TokenType.PARAMETER, # used for breaking a var like x'y' but nothing else # the token type doesn't matter "'": TokenType.QUOTE, @@ -520,6 +530,7 @@ class Tokenizer(metaclass=_Tokenizer): "WHERE": TokenType.WHERE, "WITH": TokenType.WITH, "WITH TIME ZONE": TokenType.WITH_TIME_ZONE, + "WITH LOCAL TIME ZONE": TokenType.WITH_LOCAL_TIME_ZONE, "WITHIN GROUP": TokenType.WITHIN_GROUP, "WITHOUT TIME ZONE": TokenType.WITHOUT_TIME_ZONE, "ARRAY": TokenType.ARRAY, @@ -561,6 +572,7 @@ class Tokenizer(metaclass=_Tokenizer): "BYTEA": TokenType.BINARY, "TIMESTAMP": TokenType.TIMESTAMP, "TIMESTAMPTZ": TokenType.TIMESTAMPTZ, + "TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ, "DATE": TokenType.DATE, "DATETIME": TokenType.DATETIME, "UNIQUE": TokenType.UNIQUE, diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 6b7bfd3..4e0a3c6 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -228,6 +228,7 @@ class TestDialect(Validator): "duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')", "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS TIMESTAMP)", "presto": "DATE_PARSE(x, '%Y-%m-%dT%H:%i:%S')", + "redshift": "TO_TIMESTAMP(x, 'YYYY-MM-DDTHH:MI:SS')", "spark": "TO_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')", }, ) @@ -237,6 +238,7 @@ class TestDialect(Validator): "duckdb": "STRPTIME('2020-01-01', '%Y-%m-%d')", "hive": "CAST('2020-01-01' AS TIMESTAMP)", "presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d')", + "redshift": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')", "spark": "TO_TIMESTAMP('2020-01-01', 'yyyy-MM-dd')", }, ) @@ -246,6 +248,7 @@ class TestDialect(Validator): "duckdb": "STRPTIME(x, '%y')", "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy')) AS TIMESTAMP)", "presto": "DATE_PARSE(x, '%y')", + "redshift": "TO_TIMESTAMP(x, 'YY')", "spark": "TO_TIMESTAMP(x, 'yy')", }, ) @@ -287,6 +290,7 @@ class TestDialect(Validator): "duckdb": "STRFTIME(x, '%Y-%m-%d')", "hive": "DATE_FORMAT(x, 'yyyy-MM-dd')", "presto": "DATE_FORMAT(x, '%Y-%m-%d')", + "redshift": "TO_CHAR(x, 'YYYY-MM-DD')", }, ) self.validate_all( @@ -295,6 +299,7 @@ class TestDialect(Validator): "duckdb": "CAST(x AS TEXT)", "hive": "CAST(x AS STRING)", "presto": "CAST(x AS VARCHAR)", + "redshift": "CAST(x AS TEXT)", }, ) self.validate_all( diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 501301f..f52decb 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -66,6 +66,9 @@ class TestDuckDB(Validator): def test_duckdb(self): self.validate_all( "LIST_VALUE(0, 1, 2)", + read={ + "spark": "ARRAY(0, 1, 2)", + }, write={ "bigquery": "[0, 1, 2]", "duckdb": "LIST_VALUE(0, 1, 2)", diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index 55086e3..a9b5168 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -131,7 +131,7 @@ class TestHive(Validator): write={ "presto": "CREATE TABLE test WITH (FORMAT = 'parquet', x = '1', Z = '2') AS SELECT 1", "hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1", - "spark": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1", + "spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1", }, ) self.validate_all( diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py new file mode 100644 index 0000000..1fadb84 --- /dev/null +++ b/tests/dialects/test_oracle.py @@ -0,0 +1,6 @@ +from tests.dialects.test_dialect import Validator + + +class TestOracle(Validator): + def test_oracle(self): + self.validate_identity("SELECT * FROM V$SESSION") diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index eb9aa5c..96c299d 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -173,7 +173,7 @@ class TestPresto(Validator): write={ "presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1", "hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1", - "spark": "CREATE TABLE test STORED AS PARQUET AS SELECT 1", + "spark": "CREATE TABLE test USING PARQUET AS SELECT 1", }, ) self.validate_all( @@ -181,7 +181,7 @@ class TestPresto(Validator): write={ "presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET', X = '1', Z = '2') AS SELECT 1", "hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('X' = '1', 'Z' = '2') AS SELECT 1", - "spark": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('X' = '1', 'Z' = '2') AS SELECT 1", + "spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('X' = '1', 'Z' = '2') AS SELECT 1", }, ) self.validate_all( diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 2eeff52..165f8e2 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -175,3 +175,48 @@ class TestSnowflake(Validator): "snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1) IGNORE NULLS OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1" }, ) + + def test_timestamps(self): + self.validate_all( + "SELECT CAST(a AS TIMESTAMP)", + write={ + "snowflake": "SELECT CAST(a AS TIMESTAMPNTZ)", + }, + ) + self.validate_all( + "SELECT a::TIMESTAMP_LTZ(9)", + write={ + "snowflake": "SELECT CAST(a AS TIMESTAMPLTZ(9))", + }, + ) + self.validate_all( + "SELECT a::TIMESTAMPLTZ", + write={ + "snowflake": "SELECT CAST(a AS TIMESTAMPLTZ)", + }, + ) + self.validate_all( + "SELECT a::TIMESTAMP WITH LOCAL TIME ZONE", + write={ + "snowflake": "SELECT CAST(a AS TIMESTAMPLTZ)", + }, + ) + self.validate_identity("SELECT EXTRACT(month FROM a)") + self.validate_all( + "SELECT EXTRACT('month', a)", + write={ + "snowflake": "SELECT EXTRACT('month' FROM a)", + }, + ) + self.validate_all( + "SELECT DATE_PART('month', a)", + write={ + "snowflake": "SELECT EXTRACT('month' FROM a)", + }, + ) + self.validate_all( + "SELECT DATE_PART(month FROM a::DATETIME)", + write={ + "snowflake": "SELECT EXTRACT(month FROM CAST(a AS DATETIME))", + }, + ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 8794fed..22f6947 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -44,15 +44,7 @@ class TestSpark(Validator): write={ "presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1", "hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1", - "spark": "CREATE TABLE test STORED AS PARQUET AS SELECT 1", - }, - ) - self.validate_all( - "CREATE TABLE test USING ICEBERG STORED AS PARQUET AS SELECT 1", - write={ - "presto": "CREATE TABLE test WITH (TABLE_FORMAT = 'ICEBERG', FORMAT = 'PARQUET') AS SELECT 1", - "hive": "CREATE TABLE test USING ICEBERG STORED AS PARQUET AS SELECT 1", - "spark": "CREATE TABLE test USING ICEBERG STORED AS PARQUET AS SELECT 1", + "spark": "CREATE TABLE test USING PARQUET AS SELECT 1", }, ) self.validate_all( @@ -86,7 +78,7 @@ COMMENT 'Test comment: blah' PARTITIONED BY ( date STRING ) -STORED AS ICEBERG +USING ICEBERG TBLPROPERTIES ( 'x' = '1' )""", diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py new file mode 100644 index 0000000..0619eaa --- /dev/null +++ b/tests/dialects/test_tsql.py @@ -0,0 +1,26 @@ +from tests.dialects.test_dialect import Validator + + +class TestTSQL(Validator): + dialect = "tsql" + + def test_tsql(self): + self.validate_identity('SELECT "x"."y" FROM foo') + + self.validate_all( + "SELECT CAST([a].[b] AS SMALLINT) FROM foo", + write={ + "tsql": 'SELECT CAST("a"."b" AS SMALLINT) FROM foo', + "spark": "SELECT CAST(`a`.`b` AS SHORT) FROM foo", + }, + ) + + def test_types(self): + self.validate_identity("CAST(x AS XML)") + self.validate_identity("CAST(x AS UNIQUEIDENTIFIER)") + self.validate_identity("CAST(x AS MONEY)") + self.validate_identity("CAST(x AS SMALLMONEY)") + self.validate_identity("CAST(x AS ROWVERSION)") + self.validate_identity("CAST(x AS IMAGE)") + self.validate_identity("CAST(x AS SQL_VARIANT)") + self.validate_identity("CAST(x AS BIT)") diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 716e457..59d584c 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -224,9 +224,6 @@ class TestExpressions(unittest.TestCase): self.assertEqual(actual_expression_2.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)") self.assertIs(actual_expression_2, expression) - with self.assertRaises(ValueError): - parse_one("a").transform(lambda n: None) - def test_transform_no_infinite_recursion(self): expression = parse_one("a") @@ -247,6 +244,35 @@ class TestExpressions(unittest.TestCase): self.assertEqual(expression.transform(fun).sql(), "SELECT a, b FROM x") + def test_transform_node_removal(self): + expression = parse_one("SELECT a, b FROM x") + + def remove_column_b(node): + if isinstance(node, exp.Column) and node.name == "b": + return None + return node + + self.assertEqual(expression.transform(remove_column_b).sql(), "SELECT a FROM x") + self.assertEqual(expression.transform(lambda _: None), None) + + expression = parse_one("CAST(x AS FLOAT)") + + def remove_non_list_arg(node): + if isinstance(node, exp.DataType): + return None + return node + + self.assertEqual(expression.transform(remove_non_list_arg).sql(), "CAST(x AS )") + + expression = parse_one("SELECT a, b FROM x") + + def remove_all_columns(node): + if isinstance(node, exp.Column): + return None + return node + + self.assertEqual(expression.transform(remove_all_columns).sql(), "SELECT FROM x") + def test_replace(self): expression = parse_one("SELECT a, b FROM x") expression.find(exp.Column).replace(parse_one("c")) diff --git a/tests/test_parser.py b/tests/test_parser.py index 1054103..9e430e2 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -114,6 +114,9 @@ class TestParser(unittest.TestCase): with self.assertRaises(ParseError): parse_one("SELECT FROM x ORDER BY") + def test_parameter(self): + self.assertEqual(parse_one("SELECT @x, @@x, @1").sql(), "SELECT @x, @@x, @1") + def test_annotations(self): expression = parse_one( """ |