From 2228e192dc1a582aa2ae004f20c692f6c7aeb853 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Mon, 23 Jan 2023 09:43:00 +0100 Subject: Merging upstream version 10.5.6. Signed-off-by: Daniel Baumann --- README.md | 2 +- sqlglot/__init__.py | 2 +- sqlglot/dialects/__init__.py | 1 + sqlglot/dialects/bigquery.py | 5 +++ sqlglot/dialects/clickhouse.py | 31 +++++++++++++ sqlglot/dialects/dialect.py | 3 +- sqlglot/dialects/hive.py | 4 +- sqlglot/dialects/mysql.py | 10 ++++- sqlglot/dialects/postgres.py | 5 +-- sqlglot/dialects/presto.py | 62 ++++++++++++++++++------- sqlglot/dialects/redshift.py | 73 ++++++++++++++++++++++++++++++ sqlglot/dialects/snowflake.py | 3 +- sqlglot/dialects/spark.py | 1 + sqlglot/dialects/sqlite.py | 16 +++++++ sqlglot/dialects/teradata.py | 87 ++++++++++++++++++++++++++++++++++++ sqlglot/dialects/tsql.py | 28 +++++++----- sqlglot/expressions.py | 68 +++++++++++++++++++++++----- sqlglot/generator.py | 82 +++++++++++++++++++++++---------- sqlglot/optimizer/optimizer.py | 5 ++- sqlglot/optimizer/qualify_columns.py | 4 +- sqlglot/parser.py | 80 ++++++++++++++++++++++++--------- sqlglot/tokens.py | 40 ++++++++++++----- tests/dialects/test_bigquery.py | 11 +++++ tests/dialects/test_clickhouse.py | 7 +++ tests/dialects/test_dialect.py | 31 ++++++++----- tests/dialects/test_hive.py | 26 ++++++++++- tests/dialects/test_mysql.py | 11 +++++ tests/dialects/test_postgres.py | 8 ---- tests/dialects/test_presto.py | 13 ++++++ tests/dialects/test_redshift.py | 82 ++++++++++++++++++++++++++++++++- tests/dialects/test_spark.py | 9 +++- tests/dialects/test_teradata.py | 23 ++++++++++ tests/dialects/test_tsql.py | 7 +++ tests/fixtures/identity.sql | 29 +++++++++++- tests/fixtures/pretty.sql | 20 +++++++++ tests/test_expressions.py | 10 +++++ tests/test_optimizer.py | 4 +- tests/test_parser.py | 6 +++ tests/test_transpile.py | 8 ++++ 39 files changed, 785 insertions(+), 132 deletions(-) create mode 100644 sqlglot/dialects/teradata.py create mode 100644 tests/dialects/test_teradata.py diff --git a/README.md b/README.md index 85a76e5..0416521 100644 --- a/README.md +++ b/README.md @@ -462,7 +462,7 @@ make check # Set SKIP_INTEGRATION=1 to skip integration tests | Query | sqlglot | sqlfluff | sqltree | sqlparse | moz_sql_parser | sqloxide | | --------------- | --------------- | --------------- | --------------- | --------------- | --------------- | --------------- | | tpch | 0.01308 (1.0) | 1.60626 (122.7) | 0.01168 (0.893) | 0.04958 (3.791) | 0.08543 (6.531) | 0.00136 (0.104) | -| short | 0.00109 (1.0) | 0.14134 (129.2) | 0.00099 (0.906) | 0.00342 (3.131) | 0.00652 (5.970) | 8.76621 (0.080) | +| short | 0.00109 (1.0) | 0.14134 (129.2) | 0.00099 (0.906) | 0.00342 (3.131) | 0.00652 (5.970) | 8.76E-5 (0.080) | | long | 0.01399 (1.0) | 2.12632 (151.9) | 0.01126 (0.805) | 0.04410 (3.151) | 0.06671 (4.767) | 0.00107 (0.076) | | crazy | 0.03969 (1.0) | 24.3777 (614.1) | 0.03917 (0.987) | 11.7043 (294.8) | 1.03280 (26.02) | 0.00625 (0.157) | diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 87fa081..f2db4f1 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -32,7 +32,7 @@ from sqlglot.parser import Parser from sqlglot.schema import MappingSchema from sqlglot.tokens import Tokenizer, TokenType -__version__ = "10.5.2" +__version__ = "10.5.6" pretty = False diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py index 2e42e7d..2084681 100644 --- a/sqlglot/dialects/__init__.py +++ b/sqlglot/dialects/__init__.py @@ -15,5 +15,6 @@ from sqlglot.dialects.spark import Spark from sqlglot.dialects.sqlite import SQLite from sqlglot.dialects.starrocks import StarRocks from sqlglot.dialects.tableau import Tableau +from sqlglot.dialects.teradata import Teradata from sqlglot.dialects.trino import Trino from sqlglot.dialects.tsql import TSQL diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index f0089e1..9ddfbea 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -165,6 +165,11 @@ class BigQuery(Dialect): TokenType.TABLE, } + ID_VAR_TOKENS = { + *parser.Parser.ID_VAR_TOKENS, # type: ignore + TokenType.VALUES, + } + class Generator(generator.Generator): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 04d46d2..1c173a4 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -4,6 +4,7 @@ import typing as t from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql +from sqlglot.errors import ParseError from sqlglot.parser import parse_var_map from sqlglot.tokens import TokenType @@ -72,6 +73,30 @@ class ClickHouse(Dialect): return this + def _parse_position(self) -> exp.Expression: + this = super()._parse_position() + # clickhouse position args are swapped + substr = this.this + this.args["this"] = this.args.get("substr") + this.args["substr"] = substr + return this + + # https://clickhouse.com/docs/en/sql-reference/statements/select/with/ + def _parse_cte(self) -> exp.Expression: + index = self._index + try: + # WITH AS + return super()._parse_cte() + except ParseError: + # WITH AS + self._retreat(index) + statement = self._parse_statement() + + if statement and isinstance(statement.this, exp.Alias): + self.raise_error("Expected CTE to have alias") + + return self.expression(exp.CTE, this=statement, alias=statement and statement.this) + class Generator(generator.Generator): STRUCT_DELIMITER = ("(", ")") @@ -110,3 +135,9 @@ class ClickHouse(Dialect): params = self.format_args(self.expressions(expression, params_name)) args = self.format_args(self.expressions(expression, args_name)) return f"({params})({args})" + + def cte_sql(self, expression: exp.CTE) -> str: + if isinstance(expression.this, exp.Alias): + return self.sql(expression, "this") + + return super().cte_sql(expression) diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 1c840da..0c2beba 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -33,6 +33,7 @@ class Dialects(str, Enum): TSQL = "tsql" DATABRICKS = "databricks" DRILL = "drill" + TERADATA = "teradata" class _Dialect(type): @@ -368,7 +369,7 @@ def locate_to_strposition(args): ) -def strposition_to_local_sql(self, expression): +def strposition_to_locate_sql(self, expression): args = self.format_args( expression.args.get("substr"), expression.this, expression.args.get("position") ) diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index ead13b1..ddfd1e8 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -13,7 +13,7 @@ from sqlglot.dialects.dialect import ( no_safe_divide_sql, no_trycast_sql, rename_func, - strposition_to_local_sql, + strposition_to_locate_sql, struct_extract_sql, timestrtotime_sql, var_map_sql, @@ -297,7 +297,7 @@ class Hive(Dialect): exp.SchemaCommentProperty: lambda self, e: self.naked_property(e), exp.SetAgg: rename_func("COLLECT_SET"), exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))", - exp.StrPosition: strposition_to_local_sql, + exp.StrPosition: strposition_to_locate_sql, exp.StrToDate: _str_to_date, exp.StrToTime: _str_to_time, exp.StrToUnix: _str_to_unix, diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 0fd7992..1bddfe1 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -10,7 +10,7 @@ from sqlglot.dialects.dialect import ( no_paren_current_date_sql, no_tablesample_sql, no_trycast_sql, - strposition_to_local_sql, + strposition_to_locate_sql, ) from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -122,6 +122,8 @@ class MySQL(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "MEDIUMTEXT": TokenType.MEDIUMTEXT, + "LONGTEXT": TokenType.LONGTEXT, "START": TokenType.BEGIN, "SEPARATOR": TokenType.SEPARATOR, "_ARMSCII8": TokenType.INTRODUCER, @@ -442,7 +444,7 @@ class MySQL(Dialect): exp.Trim: _trim_sql, exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"), exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")), - exp.StrPosition: strposition_to_local_sql, + exp.StrPosition: strposition_to_locate_sql, } ROOT_PROPERTIES = { @@ -454,6 +456,10 @@ class MySQL(Dialect): exp.LikeProperty, } + TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy() + TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMTEXT) + TYPE_MAPPING.pop(exp.DataType.Type.LONGTEXT) + WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set() def show_sql(self, expression): diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index f3fec31..6f597f1 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -223,19 +223,15 @@ class Postgres(Dialect): "~~*": TokenType.ILIKE, "~*": TokenType.IRLIKE, "~": TokenType.RLIKE, - "ALWAYS": TokenType.ALWAYS, "BEGIN": TokenType.COMMAND, "BEGIN TRANSACTION": TokenType.BEGIN, "BIGSERIAL": TokenType.BIGSERIAL, - "BY DEFAULT": TokenType.BY_DEFAULT, "CHARACTER VARYING": TokenType.VARCHAR, "COMMENT ON": TokenType.COMMAND, "DECLARE": TokenType.COMMAND, "DO": TokenType.COMMAND, - "GENERATED": TokenType.GENERATED, "GRANT": TokenType.COMMAND, "HSTORE": TokenType.HSTORE, - "IDENTITY": TokenType.IDENTITY, "JSONB": TokenType.JSONB, "REFRESH": TokenType.COMMAND, "REINDEX": TokenType.COMMAND, @@ -299,6 +295,7 @@ class Postgres(Dialect): exp.StrPosition: str_position_sql, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.Substring: _substring_sql, + exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)", exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", exp.TableSample: no_tablesample_sql, exp.Trim: trim_sql, diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index e16ea1d..a79a9f9 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -8,7 +8,6 @@ from sqlglot.dialects.dialect import ( no_ilike_sql, no_safe_divide_sql, rename_func, - str_position_sql, struct_extract_sql, timestrtotime_sql, ) @@ -24,14 +23,6 @@ def _approx_distinct_sql(self, expression): return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})" -def _concat_ws_sql(self, expression): - sep, *args = expression.expressions - sep = self.sql(sep) - if len(args) > 1: - return f"ARRAY_JOIN(ARRAY[{self.format_args(*args)}], {sep})" - return f"ARRAY_JOIN({self.sql(args[0])}, {sep})" - - def _datatype_sql(self, expression): sql = self.datatype_sql(expression) if expression.this == exp.DataType.Type.TIMESTAMPTZ: @@ -61,7 +52,7 @@ def _initcap_sql(self, expression): def _decode_sql(self, expression): _ensure_utf8(expression.args.get("charset")) - return f"FROM_UTF8({self.sql(expression, 'this')})" + return f"FROM_UTF8({self.format_args(expression.this, expression.args.get('replace'))})" def _encode_sql(self, expression): @@ -119,6 +110,38 @@ def _ensure_utf8(charset): raise UnsupportedError(f"Unsupported charset {charset}") +def _approx_percentile(args): + if len(args) == 4: + return exp.ApproxQuantile( + this=seq_get(args, 0), + weight=seq_get(args, 1), + quantile=seq_get(args, 2), + accuracy=seq_get(args, 3), + ) + if len(args) == 3: + return exp.ApproxQuantile( + this=seq_get(args, 0), + quantile=seq_get(args, 1), + accuracy=seq_get(args, 2), + ) + return exp.ApproxQuantile.from_arg_list(args) + + +def _from_unixtime(args): + if len(args) == 3: + return exp.UnixToTime( + this=seq_get(args, 0), + hours=seq_get(args, 1), + minutes=seq_get(args, 2), + ) + if len(args) == 2: + return exp.UnixToTime( + this=seq_get(args, 0), + zone=seq_get(args, 1), + ) + return exp.UnixToTime.from_arg_list(args) + + class Presto(Dialect): index_offset = 1 null_ordering = "nulls_are_last" @@ -150,19 +173,25 @@ class Presto(Dialect): ), "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"), "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"), - "FROM_UNIXTIME": exp.UnixToTime.from_arg_list, - "STRPOS": exp.StrPosition.from_arg_list, + "FROM_UNIXTIME": _from_unixtime, + "STRPOS": lambda args: exp.StrPosition( + this=seq_get(args, 0), + substr=seq_get(args, 1), + instance=seq_get(args, 2), + ), "TO_UNIXTIME": exp.TimeToUnix.from_arg_list, - "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, + "APPROX_PERCENTILE": _approx_percentile, "FROM_HEX": exp.Unhex.from_arg_list, "TO_HEX": exp.Hex.from_arg_list, "TO_UTF8": lambda args: exp.Encode( this=seq_get(args, 0), charset=exp.Literal.string("utf-8") ), "FROM_UTF8": lambda args: exp.Decode( - this=seq_get(args, 0), charset=exp.Literal.string("utf-8") + this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8") ), } + FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy() + FUNCTION_PARSERS.pop("TRIM") class Generator(generator.Generator): @@ -194,7 +223,6 @@ class Presto(Dialect): exp.BitwiseOr: lambda self, e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})", exp.BitwiseRightShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})", exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})", - exp.ConcatWs: _concat_ws_sql, exp.DataType: _datatype_sql, exp.DateAdd: lambda self, e: f"""DATE_ADD({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""", exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""", @@ -209,12 +237,13 @@ class Presto(Dialect): exp.Initcap: _initcap_sql, exp.Lateral: _explode_to_unnest_sql, exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), + exp.LogicalOr: rename_func("BOOL_OR"), exp.Quantile: _quantile_sql, exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"), exp.SafeDivide: no_safe_divide_sql, exp.Schema: _schema_sql, exp.SortArray: _no_sort_array, - exp.StrPosition: str_position_sql, + exp.StrPosition: rename_func("STRPOS"), exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)", exp.StrToTime: _str_to_time_sql, exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))", @@ -233,6 +262,7 @@ class Presto(Dialect): exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})", exp.UnixToTime: rename_func("FROM_UNIXTIME"), exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)", + exp.VariancePop: rename_func("VAR_POP"), } def transaction_sql(self, expression): diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 27dfb93..afd7913 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, transforms from sqlglot.dialects.dialect import rename_func from sqlglot.dialects.postgres import Postgres @@ -21,6 +23,19 @@ class Redshift(Postgres): "NVL": exp.Coalesce.from_arg_list, } + def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]: + this = super()._parse_types(check_func=check_func) + + if ( + isinstance(this, exp.DataType) + and this.this == exp.DataType.Type.VARCHAR + and this.expressions + and this.expressions[0] == exp.column("MAX") + ): + this.set("expressions", [exp.Var(this="MAX")]) + + return this + class Tokenizer(Postgres.Tokenizer): ESCAPES = ["\\"] @@ -52,6 +67,10 @@ class Redshift(Postgres): exp.DistStyleProperty, } + WITH_PROPERTIES = { + exp.LikeProperty, + } + TRANSFORMS = { **Postgres.Generator.TRANSFORMS, # type: ignore **transforms.ELIMINATE_DISTINCT_ON, # type: ignore @@ -60,3 +79,57 @@ class Redshift(Postgres): exp.DistStyleProperty: lambda self, e: self.naked_property(e), exp.Matches: rename_func("DECODE"), } + + def values_sql(self, expression: exp.Values) -> str: + """ + Converts `VALUES...` expression into a series of unions. + + Note: If you have a lot of unions then this will result in a large number of recursive statements to + evaluate the expression. You may need to increase `sys.setrecursionlimit` to run and it can also be + very slow. + """ + if not isinstance(expression.unnest().parent, exp.From): + return super().values_sql(expression) + rows = [tuple_exp.expressions for tuple_exp in expression.expressions] + selects = [] + for i, row in enumerate(rows): + if i == 0: + row = [ + exp.alias_(value, column_name) + for value, column_name in zip(row, expression.args["alias"].args["columns"]) + ] + selects.append(exp.Select(expressions=row)) + subquery_expression = selects[0] + if len(selects) > 1: + for select in selects[1:]: + subquery_expression = exp.union(subquery_expression, select, distinct=False) + return self.subquery_sql(subquery_expression.subquery(expression.alias)) + + def with_properties(self, properties: exp.Properties) -> str: + """Redshift doesn't have `WITH` as part of their with_properties so we remove it""" + return self.properties(properties, prefix=" ", suffix="") + + def renametable_sql(self, expression: exp.RenameTable) -> str: + """Redshift only supports defining the table name itself (not the db) when renaming tables""" + expression = expression.copy() + target_table = expression.this + for arg in target_table.args: + if arg != "this": + target_table.set(arg, None) + this = self.sql(expression, "this") + return f"RENAME TO {this}" + + def datatype_sql(self, expression: exp.DataType) -> str: + """ + Redshift converts the `TEXT` data type to `VARCHAR(255)` by default when people more generally mean + VARCHAR of max length which is `VARCHAR(max)` in Redshift. Therefore if we get a `TEXT` data type + without precision we convert it to `VARCHAR(max)` and if it does have precision then we just convert + `TEXT` to `VARCHAR`. + """ + if expression.this == exp.DataType.Type.TEXT: + expression = expression.copy() + expression.set("this", exp.DataType.Type.VARCHAR) + precision = expression.args.get("expressions") + if not precision: + expression.append("expressions", exp.Var(this="MAX")) + return super().datatype_sql(expression) diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 24d3bdf..c44950a 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -210,6 +210,7 @@ class Snowflake(Dialect): **generator.Generator.TRANSFORMS, # type: ignore exp.Array: inline_array_sql, exp.ArrayConcat: rename_func("ARRAY_CAT"), + exp.DateAdd: rename_func("DATEADD"), exp.DateStrToDate: datestrtodate_sql, exp.DataType: _datatype_sql, exp.If: rename_func("IFF"), @@ -218,7 +219,7 @@ class Snowflake(Dialect): exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}", exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.Matches: rename_func("DECODE"), - exp.StrPosition: rename_func("POSITION"), + exp.StrPosition: lambda self, e: f"{self.normalize_func('POSITION')}({self.format_args(e.args.get('substr'), e.this, e.args.get('position'))})", exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeStrToTime: timestrtotime_sql, exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 7f05dea..42d34c2 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -124,6 +124,7 @@ class Spark(Hive): exp.TimestampTrunc: lambda self, e: f"DATE_TRUNC({self.sql(e, 'unit')}, {self.sql(e, 'this')})", exp.VariancePop: rename_func("VAR_POP"), exp.DateFromParts: rename_func("MAKE_DATE"), + exp.LogicalOr: rename_func("BOOL_OR"), } TRANSFORMS.pop(exp.ArraySort) TRANSFORMS.pop(exp.ILike) diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index a0c4942..1b39449 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -13,6 +13,10 @@ from sqlglot.dialects.dialect import ( from sqlglot.tokens import TokenType +def _fetch_sql(self, expression): + return self.limit_sql(exp.Limit(expression=expression.args.get("count"))) + + # https://www.sqlite.org/lang_aggfunc.html#group_concat def _group_concat_sql(self, expression): this = expression.this @@ -30,6 +34,14 @@ def _group_concat_sql(self, expression): return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})" +def _date_add_sql(self, expression): + modifier = expression.expression + modifier = expression.name if modifier.is_string else self.sql(modifier) + unit = expression.args.get("unit") + modifier = f"'{modifier} {unit.name}'" if unit else f"'{modifier}'" + return f"{self.normalize_func('DATE')}({self.format_args(expression.this, modifier)})" + + class SQLite(Dialect): class Tokenizer(tokens.Tokenizer): IDENTIFIERS = ['"', ("[", "]"), "`"] @@ -71,6 +83,7 @@ class SQLite(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore + exp.DateAdd: _date_add_sql, exp.ILike: no_ilike_sql, exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, @@ -78,8 +91,11 @@ class SQLite(Dialect): exp.JSONBExtractScalar: arrow_json_extract_scalar_sql, exp.Levenshtein: rename_func("EDITDIST3"), exp.TableSample: no_tablesample_sql, + exp.DateStrToDate: lambda self, e: self.sql(e, "this"), + exp.TimeStrToTime: lambda self, e: self.sql(e, "this"), exp.TryCast: no_trycast_sql, exp.GroupConcat: _group_concat_sql, + exp.Fetch: _fetch_sql, } def transaction_sql(self, expression): diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py new file mode 100644 index 0000000..4340820 --- /dev/null +++ b/sqlglot/dialects/teradata.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from sqlglot import exp, generator, parser +from sqlglot.dialects.dialect import Dialect +from sqlglot.tokens import TokenType + + +class Teradata(Dialect): + class Parser(parser.Parser): + CHARSET_TRANSLATORS = { + "GRAPHIC_TO_KANJISJIS", + "GRAPHIC_TO_LATIN", + "GRAPHIC_TO_UNICODE", + "GRAPHIC_TO_UNICODE_PadSpace", + "KANJI1_KanjiEBCDIC_TO_UNICODE", + "KANJI1_KanjiEUC_TO_UNICODE", + "KANJI1_KANJISJIS_TO_UNICODE", + "KANJI1_SBC_TO_UNICODE", + "KANJISJIS_TO_GRAPHIC", + "KANJISJIS_TO_LATIN", + "KANJISJIS_TO_UNICODE", + "LATIN_TO_GRAPHIC", + "LATIN_TO_KANJISJIS", + "LATIN_TO_UNICODE", + "LOCALE_TO_UNICODE", + "UNICODE_TO_GRAPHIC", + "UNICODE_TO_GRAPHIC_PadGraphic", + "UNICODE_TO_GRAPHIC_VarGraphic", + "UNICODE_TO_KANJI1_KanjiEBCDIC", + "UNICODE_TO_KANJI1_KanjiEUC", + "UNICODE_TO_KANJI1_KANJISJIS", + "UNICODE_TO_KANJI1_SBC", + "UNICODE_TO_KANJISJIS", + "UNICODE_TO_LATIN", + "UNICODE_TO_LOCALE", + "UNICODE_TO_UNICODE_FoldSpace", + "UNICODE_TO_UNICODE_Fullwidth", + "UNICODE_TO_UNICODE_Halfwidth", + "UNICODE_TO_UNICODE_NFC", + "UNICODE_TO_UNICODE_NFD", + "UNICODE_TO_UNICODE_NFKC", + "UNICODE_TO_UNICODE_NFKD", + } + + FUNCTION_PARSERS = { + **parser.Parser.FUNCTION_PARSERS, # type: ignore + "TRANSLATE": lambda self: self._parse_translate(self.STRICT_CAST), + } + + def _parse_translate(self, strict: bool) -> exp.Expression: + this = self._parse_conjunction() + + if not self._match(TokenType.USING): + self.raise_error("Expected USING in TRANSLATE") + + if self._match_texts(self.CHARSET_TRANSLATORS): + charset_split = self._prev.text.split("_TO_") + to = self.expression(exp.CharacterSet, this=charset_split[1]) + else: + self.raise_error("Expected a character set translator after USING in TRANSLATE") + + return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) + + # FROM before SET in Teradata UPDATE syntax + # https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause + def _parse_update(self) -> exp.Expression: + return self.expression( + exp.Update, + **{ # type: ignore + "this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS), + "from": self._parse_from(), + "expressions": self._match(TokenType.SET) + and self._parse_csv(self._parse_equality), + "where": self._parse_where(), + }, + ) + + class Generator(generator.Generator): + # FROM before SET in Teradata UPDATE syntax + # https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause + def update_sql(self, expression: exp.Update) -> str: + this = self.sql(expression, "this") + from_sql = self.sql(expression, "from") + set_sql = self.expressions(expression, flat=True) + where_sql = self.sql(expression, "where") + sql = f"UPDATE {this}{from_sql} SET {set_sql}{where_sql}" + return self.prepend_ctes(expression, sql) diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 465f534..9342e6b 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -243,28 +243,34 @@ class TSQL(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "BIT": TokenType.BOOLEAN, - "REAL": TokenType.FLOAT, - "NTEXT": TokenType.TEXT, - "SMALLDATETIME": TokenType.DATETIME, "DATETIME2": TokenType.DATETIME, "DATETIMEOFFSET": TokenType.TIMESTAMPTZ, - "TIME": TokenType.TIMESTAMP, + "DECLARE": TokenType.COMMAND, "IMAGE": TokenType.IMAGE, "MONEY": TokenType.MONEY, - "SMALLMONEY": TokenType.SMALLMONEY, + "NTEXT": TokenType.TEXT, + "NVARCHAR(MAX)": TokenType.TEXT, + "PRINT": TokenType.COMMAND, + "REAL": TokenType.FLOAT, "ROWVERSION": TokenType.ROWVERSION, - "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER, - "XML": TokenType.XML, + "SMALLDATETIME": TokenType.DATETIME, + "SMALLMONEY": TokenType.SMALLMONEY, "SQL_VARIANT": TokenType.VARIANT, - "NVARCHAR(MAX)": TokenType.TEXT, - "VARCHAR(MAX)": TokenType.TEXT, + "TIME": TokenType.TIMESTAMP, "TOP": TokenType.TOP, + "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER, + "VARCHAR(MAX)": TokenType.TEXT, + "XML": TokenType.XML, } class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore - "CHARINDEX": exp.StrPosition.from_arg_list, + "CHARINDEX": lambda args: exp.StrPosition( + this=seq_get(args, 1), + substr=seq_get(args, 0), + position=seq_get(args, 2), + ), "ISNULL": exp.Coalesce.from_arg_list, "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), @@ -288,7 +294,7 @@ class TSQL(Dialect): } # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-temporary#create-a-temporary-table - TABLE_PREFIX_TOKENS = {TokenType.HASH} + TABLE_PREFIX_TOKENS = {TokenType.HASH, TokenType.PARAMETER} def _parse_convert(self, strict): to = self._parse_types() diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index d093e29..be99fe2 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -653,6 +653,7 @@ class Create(Expression): "statistics": False, "no_primary_index": False, "indexes": False, + "no_schema_binding": False, } @@ -770,6 +771,10 @@ class AlterColumn(Expression): } +class RenameTable(Expression): + pass + + class ColumnConstraint(Expression): arg_types = {"this": False, "kind": True} @@ -804,7 +809,7 @@ class EncodeColumnConstraint(ColumnConstraintKind): class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind): # this: True -> ALWAYS, this: False -> BY DEFAULT - arg_types = {"this": True, "expression": False} + arg_types = {"this": True, "start": False, "increment": False} class NotNullColumnConstraint(ColumnConstraintKind): @@ -1266,7 +1271,7 @@ class Tuple(Expression): class Subqueryable(Unionable): - def subquery(self, alias=None, copy=True): + def subquery(self, alias=None, copy=True) -> Subquery: """ Convert this expression to an aliased expression that can be used as a Subquery. @@ -1460,6 +1465,7 @@ class Unnest(UDTF): "expressions": True, "ordinality": False, "alias": False, + "offset": False, } @@ -2126,6 +2132,7 @@ class DataType(Expression): "this": True, "expressions": False, "nested": False, + "values": False, } class Type(AutoName): @@ -2134,6 +2141,8 @@ class DataType(Expression): VARCHAR = auto() NVARCHAR = auto() TEXT = auto() + MEDIUMTEXT = auto() + LONGTEXT = auto() BINARY = auto() VARBINARY = auto() INT = auto() @@ -2791,7 +2800,7 @@ class Day(Func): class Decode(Func): - arg_types = {"this": True, "charset": True} + arg_types = {"this": True, "charset": True, "replace": False} class DiToDate(Func): @@ -2815,7 +2824,7 @@ class Floor(Func): class Greatest(Func): - arg_types = {"this": True, "expressions": True} + arg_types = {"this": True, "expressions": False} is_var_len_args = True @@ -2861,7 +2870,7 @@ class JSONBExtractScalar(JSONExtract): class Least(Func): - arg_types = {"this": True, "expressions": True} + arg_types = {"this": True, "expressions": False} is_var_len_args = True @@ -2904,7 +2913,7 @@ class Lower(Func): class Map(Func): - arg_types = {"keys": True, "values": True} + arg_types = {"keys": False, "values": False} class VarMap(Func): @@ -2923,11 +2932,11 @@ class Matches(Func): class Max(AggFunc): - pass + arg_types = {"this": True, "expression": False} class Min(AggFunc): - pass + arg_types = {"this": True, "expression": False} class Month(Func): @@ -2962,7 +2971,7 @@ class QuantileIf(AggFunc): class ApproxQuantile(Quantile): - arg_types = {"this": True, "quantile": True, "accuracy": False} + arg_types = {"this": True, "quantile": True, "accuracy": False, "weight": False} class ReadCSV(Func): @@ -3022,7 +3031,12 @@ class Substring(Func): class StrPosition(Func): - arg_types = {"substr": True, "this": True, "position": False} + arg_types = { + "this": True, + "substr": True, + "position": False, + "instance": False, + } class StrToDate(Func): @@ -3129,8 +3143,10 @@ class UnixToStr(Func): arg_types = {"this": True, "format": False} +# https://prestodb.io/docs/current/functions/datetime.html +# presto has weird zone/hours/minutes class UnixToTime(Func): - arg_types = {"this": True, "scale": False} + arg_types = {"this": True, "scale": False, "zone": False, "hours": False, "minutes": False} SECONDS = Literal.string("seconds") MILLIS = Literal.string("millis") @@ -3684,6 +3700,16 @@ def to_identifier(alias, quoted=None) -> t.Optional[Identifier]: return identifier +@t.overload +def to_table(sql_path: str | Table, **kwargs) -> Table: + ... + + +@t.overload +def to_table(sql_path: None, **kwargs) -> None: + ... + + def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]: """ Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional. @@ -3860,6 +3886,26 @@ def values( ) +def rename_table(old_name: str | Table, new_name: str | Table) -> AlterTable: + """Build ALTER TABLE... RENAME... expression + + Args: + old_name: The old name of the table + new_name: The new name of the table + + Returns: + Alter table expression + """ + old_table = to_table(old_name) + new_table = to_table(new_name) + return AlterTable( + this=old_table, + actions=[ + RenameTable(this=new_table), + ], + ) + + def convert(value) -> Expression: """Convert a python value into an expression object. diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 3935133..6375d92 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -82,6 +82,8 @@ class Generator: TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", exp.DataType.Type.NVARCHAR: "VARCHAR", + exp.DataType.Type.MEDIUMTEXT: "TEXT", + exp.DataType.Type.LONGTEXT: "TEXT", } TOKEN_MAPPING: t.Dict[TokenType, str] = {} @@ -105,6 +107,7 @@ class Generator: } WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary) + SENTINEL_LINE_BREAK = "__SQLGLOT__LB__" __slots__ = ( "time_mapping", @@ -211,6 +214,8 @@ class Generator: elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages: raise UnsupportedError(concat_messages(self.unsupported_messages, self.max_unsupported)) + if self.pretty: + sql = sql.replace(self.SENTINEL_LINE_BREAK, "\n") return sql def unsupported(self, message: str) -> None: @@ -401,7 +406,17 @@ class Generator: def generatedasidentitycolumnconstraint_sql( self, expression: exp.GeneratedAsIdentityColumnConstraint ) -> str: - return f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY" + start = expression.args.get("start") + start = f"START WITH {start}" if start else "" + increment = expression.args.get("increment") + increment = f"INCREMENT BY {increment}" if increment else "" + sequence_opts = "" + if start or increment: + sequence_opts = f"{start} {increment}" + sequence_opts = f" ({sequence_opts.strip()})" + return ( + f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY{sequence_opts}" + ) def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str: return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL" @@ -475,10 +490,13 @@ class Generator: materialized, ) ) + no_schema_binding = ( + " WITH NO SCHEMA BINDING" if expression.args.get("no_schema_binding") else "" + ) post_expression_modifiers = "".join((data, statistics, no_primary_index)) - expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties}{expression_sql}{post_expression_modifiers}{index_sql}" + expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties}{expression_sql}{post_expression_modifiers}{index_sql}{no_schema_binding}" return self.prepend_ctes(expression, expression_sql) def describe_sql(self, expression: exp.Describe) -> str: @@ -517,13 +535,19 @@ class Generator: type_sql = self.TYPE_MAPPING.get(type_value, type_value.value) nested = "" interior = self.expressions(expression, flat=True) + values = "" if interior: - nested = ( - f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}" - if expression.args.get("nested") - else f"({interior})" - ) - return f"{type_sql}{nested}" + if expression.args.get("nested"): + nested = f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}" + if expression.args.get("values") is not None: + delimiters = ("[", "]") if type_value == exp.DataType.Type.ARRAY else ("(", ")") + values = ( + f"{delimiters[0]}{self.expressions(expression, 'values')}{delimiters[1]}" + ) + else: + nested = f"({interior})" + + return f"{type_sql}{nested}{values}" def directory_sql(self, expression: exp.Directory) -> str: local = "LOCAL " if expression.args.get("local") else "" @@ -622,10 +646,14 @@ class Generator: return self.sep() + self.expressions(properties, indent=False, sep=" ") return "" - def properties(self, properties: exp.Properties, prefix: str = "", sep: str = ", ") -> str: + def properties( + self, properties: exp.Properties, prefix: str = "", sep: str = ", ", suffix: str = "" + ) -> str: if properties.expressions: expressions = self.expressions(properties, sep=sep, indent=False) - return f"{prefix}{' ' if prefix else ''}{self.wrap(expressions)}" + return ( + f"{prefix}{' ' if prefix and prefix != ' ' else ''}{self.wrap(expressions)}{suffix}" + ) return "" def with_properties(self, properties: exp.Properties) -> str: @@ -763,14 +791,15 @@ class Generator: return self.prepend_ctes(expression, sql) def values_sql(self, expression: exp.Values) -> str: - alias = self.sql(expression, "alias") args = self.expressions(expression) - if not alias: - return f"VALUES{self.seg('')}{args}" - alias = f" AS {alias}" if alias else alias - if self.WRAP_DERIVED_VALUES: - return f"(VALUES{self.seg('')}{args}){alias}" - return f"VALUES{self.seg('')}{args}{alias}" + alias = self.sql(expression, "alias") + values = f"VALUES{self.seg('')}{args}" + values = ( + f"({values})" + if self.WRAP_DERIVED_VALUES and (alias or isinstance(expression.parent, exp.From)) + else values + ) + return f"{values} AS {alias}" if alias else values def var_sql(self, expression: exp.Var) -> str: return self.sql(expression, "this") @@ -868,6 +897,8 @@ class Generator: if self._replace_backslash: text = text.replace("\\", "\\\\") text = text.replace(self.quote_end, self._escaped_quote_end) + if self.pretty: + text = text.replace("\n", self.SENTINEL_LINE_BREAK) text = f"{self.quote_start}{text}{self.quote_end}" return text @@ -1036,7 +1067,9 @@ class Generator: alias = self.sql(expression, "alias") alias = f" AS {alias}" if alias else alias ordinality = " WITH ORDINALITY" if expression.args.get("ordinality") else "" - return f"UNNEST({args}){ordinality}{alias}" + offset = expression.args.get("offset") + offset = f" WITH OFFSET AS {self.sql(offset)}" if offset else "" + return f"UNNEST({args}){ordinality}{alias}{offset}" def where_sql(self, expression: exp.Where) -> str: this = self.indent(self.sql(expression, "this")) @@ -1132,15 +1165,14 @@ class Generator: return f"EXTRACT({this} FROM {expression_sql})" def trim_sql(self, expression: exp.Trim) -> str: - target = self.sql(expression, "this") trim_type = self.sql(expression, "position") if trim_type == "LEADING": - return f"LTRIM({target})" + return f"{self.normalize_func('LTRIM')}({self.format_args(expression.this)})" elif trim_type == "TRAILING": - return f"RTRIM({target})" + return f"{self.normalize_func('RTRIM')}({self.format_args(expression.this)})" else: - return f"TRIM({target})" + return f"{self.normalize_func('TRIM')}({self.format_args(expression.this, expression.expression)})" def concat_sql(self, expression: exp.Concat) -> str: if len(expression.expressions) == 1: @@ -1317,6 +1349,10 @@ class Generator: return f"ALTER COLUMN {this} DROP DEFAULT" + def renametable_sql(self, expression: exp.RenameTable) -> str: + this = self.sql(expression, "this") + return f"RENAME TO {this}" + def altertable_sql(self, expression: exp.AlterTable) -> str: actions = expression.args["actions"] @@ -1326,7 +1362,7 @@ class Generator: actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ") elif isinstance(actions[0], exp.Drop): actions = self.expressions(expression, "actions") - elif isinstance(actions[0], exp.AlterColumn): + elif isinstance(actions[0], (exp.AlterColumn, exp.RenameTable)): actions = self.sql(actions[0]) else: self.unsupported(f"Unsupported ALTER TABLE action {actions[0].__class__.__name__}") diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index 46b6b30..5258c2b 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -52,7 +52,10 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar If no schema is provided then the default schema defined at `sqlgot.schema` will be used db (str): specify the default database, as might be set by a `USE DATABASE db` statement catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement - rules (sequence): sequence of optimizer rules to use + rules (sequence): sequence of optimizer rules to use. + Many of the rules require tables and columns to be qualified. + Do not remove qualify_tables or qualify_columns from the sequence of rules unless you know + what you're doing! **kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in. Returns: sqlglot.Expression: optimized expression diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index f4568c2..8da4e43 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -1,7 +1,7 @@ import itertools from sqlglot import alias, exp -from sqlglot.errors import OptimizeError +from sqlglot.errors import OptimizeError, SchemaError from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.schema import ensure_schema @@ -382,7 +382,7 @@ class _Resolver: try: return self.schema.column_names(source, only_visible) except Exception as e: - raise OptimizeError(str(e)) from e + raise SchemaError(str(e)) from e if isinstance(source, Scope) and isinstance(source.expression, exp.Values): return source.expression.alias_column_names diff --git a/sqlglot/parser.py b/sqlglot/parser.py index bd95db8..c97b19a 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -107,6 +107,8 @@ class Parser(metaclass=_Parser): TokenType.VARCHAR, TokenType.NVARCHAR, TokenType.TEXT, + TokenType.MEDIUMTEXT, + TokenType.LONGTEXT, TokenType.BINARY, TokenType.VARBINARY, TokenType.JSON, @@ -233,6 +235,7 @@ class Parser(metaclass=_Parser): TokenType.UNPIVOT, TokenType.PROPERTIES, TokenType.PROCEDURE, + TokenType.VIEW, TokenType.VOLATILE, TokenType.WINDOW, *SUBQUERY_PREDICATES, @@ -252,6 +255,7 @@ class Parser(metaclass=_Parser): TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH} FUNC_TOKENS = { + TokenType.COMMAND, TokenType.CURRENT_DATE, TokenType.CURRENT_DATETIME, TokenType.CURRENT_TIMESTAMP, @@ -552,7 +556,7 @@ class Parser(metaclass=_Parser): TokenType.IF: lambda self: self._parse_if(), } - FUNCTION_PARSERS = { + FUNCTION_PARSERS: t.Dict[str, t.Callable] = { "CONVERT": lambda self: self._parse_convert(self.STRICT_CAST), "TRY_CONVERT": lambda self: self._parse_convert(False), "EXTRACT": lambda self: self._parse_extract(), @@ -937,6 +941,7 @@ class Parser(metaclass=_Parser): statistics = None no_primary_index = None indexes = None + no_schema_binding = None if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE): this = self._parse_user_defined_function() @@ -975,6 +980,9 @@ class Parser(metaclass=_Parser): break else: indexes.append(index) + elif create_token.token_type == TokenType.VIEW: + if self._match_text_seq("WITH", "NO", "SCHEMA", "BINDING"): + no_schema_binding = True return self.expression( exp.Create, @@ -993,6 +1001,7 @@ class Parser(metaclass=_Parser): statistics=statistics, no_primary_index=no_primary_index, indexes=indexes, + no_schema_binding=no_schema_binding, ) def _parse_property(self) -> t.Optional[exp.Expression]: @@ -1246,8 +1255,14 @@ class Parser(metaclass=_Parser): return self.expression(exp.Partition, this=self._parse_wrapped_csv(parse_values)) def _parse_value(self) -> exp.Expression: - expressions = self._parse_wrapped_csv(self._parse_conjunction) - return self.expression(exp.Tuple, expressions=expressions) + if self._match(TokenType.L_PAREN): + expressions = self._parse_csv(self._parse_conjunction) + self._match_r_paren() + return self.expression(exp.Tuple, expressions=expressions) + + # In presto we can have VALUES 1, 2 which results in 1 column & 2 rows. + # Source: https://prestodb.io/docs/current/sql/values.html + return self.expression(exp.Tuple, expressions=[self._parse_conjunction()]) def _parse_select( self, nested: bool = False, table: bool = False, parse_subquery_alias: bool = True @@ -1313,19 +1328,9 @@ class Parser(metaclass=_Parser): # Union ALL should be a property of the top select node, not the subquery return self._parse_subquery(this, parse_alias=parse_subquery_alias) elif self._match(TokenType.VALUES): - if self._curr.token_type == TokenType.L_PAREN: - # We don't consume the left paren because it's consumed in _parse_value - expressions = self._parse_csv(self._parse_value) - else: - # In presto we can have VALUES 1, 2 which results in 1 column & 2 rows. - # Source: https://prestodb.io/docs/current/sql/values.html - expressions = self._parse_csv( - lambda: self.expression(exp.Tuple, expressions=[self._parse_conjunction()]) - ) - this = self.expression( exp.Values, - expressions=expressions, + expressions=self._parse_csv(self._parse_value), alias=self._parse_table_alias(), ) else: @@ -1612,13 +1617,12 @@ class Parser(metaclass=_Parser): if alias: this.set("alias", alias) - if self._match(TokenType.WITH): + if self._match_pair(TokenType.WITH, TokenType.L_PAREN): this.set( "hints", - self._parse_wrapped_csv( - lambda: self._parse_function() or self._parse_var(any_token=True) - ), + self._parse_csv(lambda: self._parse_function() or self._parse_var(any_token=True)), ) + self._match_r_paren() if not self.alias_post_tablesample: table_sample = self._parse_table_sample() @@ -1643,8 +1647,17 @@ class Parser(metaclass=_Parser): alias.set("columns", [alias.this]) alias.set("this", None) + offset = None + if self._match_pair(TokenType.WITH, TokenType.OFFSET): + self._match(TokenType.ALIAS) + offset = self._parse_conjunction() + return self.expression( - exp.Unnest, expressions=expressions, ordinality=ordinality, alias=alias + exp.Unnest, + expressions=expressions, + ordinality=ordinality, + alias=alias, + offset=offset, ) def _parse_derived_table_values(self) -> t.Optional[exp.Expression]: @@ -1999,7 +2012,7 @@ class Parser(metaclass=_Parser): this = self._parse_column() if type_token: - if this: + if this and not isinstance(this, exp.Star): return self.expression(exp.Cast, this=this, to=type_token) if not type_token.args.get("expressions"): self._retreat(index) @@ -2050,6 +2063,7 @@ class Parser(metaclass=_Parser): self._retreat(index) return None + values: t.Optional[t.List[t.Optional[exp.Expression]]] = None if nested and self._match(TokenType.LT): if is_struct: expressions = self._parse_csv(self._parse_struct_kwargs) @@ -2059,6 +2073,10 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.GT): self.raise_error("Expecting >") + if self._match_set((TokenType.L_BRACKET, TokenType.L_PAREN)): + values = self._parse_csv(self._parse_conjunction) + self._match_set((TokenType.R_BRACKET, TokenType.R_PAREN)) + value: t.Optional[exp.Expression] = None if type_token in self.TIMESTAMPS: if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ: @@ -2097,9 +2115,13 @@ class Parser(metaclass=_Parser): this=exp.DataType.Type[type_token.value.upper()], expressions=expressions, nested=nested, + values=values, ) def _parse_struct_kwargs(self) -> t.Optional[exp.Expression]: + if self._curr and self._curr.token_type in self.TYPE_TOKENS: + return self._parse_types() + this = self._parse_id_var() self._match(TokenType.COLON) data_type = self._parse_types() @@ -2412,6 +2434,14 @@ class Parser(metaclass=_Parser): self._match(TokenType.ALWAYS) kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True) self._match_pair(TokenType.ALIAS, TokenType.IDENTITY) + + if self._match(TokenType.L_PAREN): + if self._match_text_seq("START", "WITH"): + kind.set("start", self._parse_bitwise()) + if self._match_text_seq("INCREMENT", "BY"): + kind.set("increment", self._parse_bitwise()) + + self._match_r_paren() else: return this @@ -2619,8 +2649,12 @@ class Parser(metaclass=_Parser): if self._match(TokenType.IN): args.append(self._parse_bitwise()) - # Note: we're parsing in order needle, haystack, position - this = exp.StrPosition.from_arg_list(args) + this = exp.StrPosition( + this=seq_get(args, 1), + substr=seq_get(args, 0), + position=seq_get(args, 2), + ) + self.validate_expression(this, args) return this @@ -2999,6 +3033,8 @@ class Parser(metaclass=_Parser): actions = self._parse_csv(self._parse_add_column) elif self._match_text_seq("DROP", advance=False): actions = self._parse_csv(self._parse_drop_column) + elif self._match_text_seq("RENAME", "TO"): + actions = self.expression(exp.RenameTable, this=self._parse_table(schema=True)) elif self._match_text_seq("ALTER"): self._match(TokenType.COLUMN) column = self._parse_field(any_token=True) diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 8e312a7..f12528f 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -82,6 +82,8 @@ class TokenType(AutoName): VARCHAR = auto() NVARCHAR = auto() TEXT = auto() + MEDIUMTEXT = auto() + LONGTEXT = auto() BINARY = auto() VARBINARY = auto() JSON = auto() @@ -434,6 +436,8 @@ class Tokenizer(metaclass=_Tokenizer): ESCAPES = ["'"] + _ESCAPES: t.Set[str] = set() + KEYWORDS = { **{ f"{key}{postfix}": TokenType.BLOCK_START @@ -461,6 +465,7 @@ class Tokenizer(metaclass=_Tokenizer): "#>>": TokenType.DHASH_ARROW, "<->": TokenType.LR_ARROW, "ALL": TokenType.ALL, + "ALWAYS": TokenType.ALWAYS, "AND": TokenType.AND, "ANTI": TokenType.ANTI, "ANY": TokenType.ANY, @@ -472,6 +477,7 @@ class Tokenizer(metaclass=_Tokenizer): "BETWEEN": TokenType.BETWEEN, "BOTH": TokenType.BOTH, "BUCKET": TokenType.BUCKET, + "BY DEFAULT": TokenType.BY_DEFAULT, "CACHE": TokenType.CACHE, "UNCACHE": TokenType.UNCACHE, "CASE": TokenType.CASE, @@ -521,9 +527,11 @@ class Tokenizer(metaclass=_Tokenizer): "FOREIGN KEY": TokenType.FOREIGN_KEY, "FORMAT": TokenType.FORMAT, "FROM": TokenType.FROM, + "GENERATED": TokenType.GENERATED, "GROUP BY": TokenType.GROUP_BY, "GROUPING SETS": TokenType.GROUPING_SETS, "HAVING": TokenType.HAVING, + "IDENTITY": TokenType.IDENTITY, "IF": TokenType.IF, "ILIKE": TokenType.ILIKE, "IMMUTABLE": TokenType.IMMUTABLE, @@ -746,7 +754,7 @@ class Tokenizer(metaclass=_Tokenizer): ) def __init__(self) -> None: - self._replace_backslash = "\\" in self._ESCAPES # type: ignore + self._replace_backslash = "\\" in self._ESCAPES self.reset() def reset(self) -> None: @@ -771,7 +779,10 @@ class Tokenizer(metaclass=_Tokenizer): self.reset() self.sql = sql self.size = len(sql) + self._scan() + return self.tokens + def _scan(self, until: t.Optional[t.Callable] = None) -> None: while self.size and not self._end: self._start = self._current self._advance() @@ -792,7 +803,9 @@ class Tokenizer(metaclass=_Tokenizer): self._scan_identifier(identifier_end) else: self._scan_keywords() - return self.tokens + + if until and until(): + break def _chars(self, size: int) -> str: if size == 1: @@ -832,11 +845,13 @@ class Tokenizer(metaclass=_Tokenizer): if token_type in self.COMMANDS and ( len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON ): - self._start = self._current - while not self._end and self._peek != ";": - self._advance() - if self._start < self._current: - self._add(TokenType.STRING) + start = self._current + tokens = len(self.tokens) + self._scan(lambda: self._peek == ";") + self.tokens = self.tokens[:tokens] + text = self.sql[start : self._current].strip() + if text: + self._add(TokenType.STRING, text) def _scan_keywords(self) -> None: size = 0 @@ -947,7 +962,8 @@ class Tokenizer(metaclass=_Tokenizer): elif self._peek.isidentifier(): # type: ignore number_text = self._text literal = [] - while self._peek.isidentifier(): # type: ignore + + while self._peek.strip() and self._peek not in self.SINGLE_TOKENS: # type: ignore literal.append(self._peek.upper()) # type: ignore self._advance() @@ -1063,8 +1079,12 @@ class Tokenizer(metaclass=_Tokenizer): delim_size = len(delimiter) while True: - if self._char in self._ESCAPES and self._peek == delimiter: # type: ignore - text += delimiter + if ( + self._char in self._ESCAPES + and self._peek + and (self._peek == delimiter or self._peek in self._ESCAPES) + ): + text += self._peek self._advance(2) else: if self._chars(delim_size) == delimiter: diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index c61a2f3..e5b1c94 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -6,6 +6,8 @@ class TestBigQuery(Validator): dialect = "bigquery" def test_bigquery(self): + self.validate_identity("SELECT STRUCT>(['2023-01-17'])") + self.validate_identity("SELECT * FROM q UNPIVOT(values FOR quarter IN (b, c))") self.validate_all( "REGEXP_CONTAINS('foo', '.*')", read={"bigquery": "REGEXP_CONTAINS('foo', '.*')"}, @@ -41,6 +43,15 @@ class TestBigQuery(Validator): "spark": r"'/\\*.*\\*/'", }, ) + self.validate_all( + r"'\\'", + write={ + "bigquery": r"'\\'", + "duckdb": r"'\'", + "presto": r"'\'", + "hive": r"'\\'", + }, + ) self.validate_all( R'R"""/\*.*\*/"""', write={ diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 109e9f3..2827dd4 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -17,6 +17,7 @@ class TestClickhouse(Validator): self.validate_identity("SELECT quantile(0.5)(a)") self.validate_identity("SELECT quantiles(0.5)(a) AS x FROM t") self.validate_identity("SELECT * FROM foo WHERE x GLOBAL IN (SELECT * FROM bar)") + self.validate_identity("position(a, b)") self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", @@ -47,3 +48,9 @@ class TestClickhouse(Validator): "clickhouse": "SELECT quantileIf(0.5)(a, TRUE)", }, ) + + def test_cte(self): + self.validate_identity("WITH 'x' AS foo SELECT foo") + self.validate_identity("WITH SUM(bytes) AS foo SELECT foo FROM system.parts") + self.validate_identity("WITH (SELECT foo) AS bar SELECT bar + 5") + self.validate_identity("WITH test1 AS (SELECT i + 1, j + 1 FROM test1) SELECT * FROM test1") diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 284a30d..b2f4676 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -14,7 +14,7 @@ class Validator(unittest.TestCase): self.assertEqual(write_sql or sql, expression.sql(dialect=self.dialect)) return expression - def validate_all(self, sql, read=None, write=None, pretty=False): + def validate_all(self, sql, read=None, write=None, pretty=False, identify=False): """ Validate that: 1. Everything in `read` transpiles to `sql` @@ -32,7 +32,10 @@ class Validator(unittest.TestCase): with self.subTest(f"{read_dialect} -> {sql}"): self.assertEqual( parse_one(read_sql, read_dialect).sql( - self.dialect, unsupported_level=ErrorLevel.IGNORE, pretty=pretty + self.dialect, + unsupported_level=ErrorLevel.IGNORE, + pretty=pretty, + identify=identify, ), sql, ) @@ -48,6 +51,7 @@ class Validator(unittest.TestCase): write_dialect, unsupported_level=ErrorLevel.IGNORE, pretty=pretty, + identify=identify, ), write_sql, ) @@ -76,7 +80,7 @@ class TestDialect(Validator): "oracle": "CAST(a AS CLOB)", "postgres": "CAST(a AS TEXT)", "presto": "CAST(a AS VARCHAR)", - "redshift": "CAST(a AS TEXT)", + "redshift": "CAST(a AS VARCHAR(MAX))", "snowflake": "CAST(a AS TEXT)", "spark": "CAST(a AS STRING)", "starrocks": "CAST(a AS STRING)", @@ -155,7 +159,7 @@ class TestDialect(Validator): "oracle": "CAST(a AS CLOB)", "postgres": "CAST(a AS TEXT)", "presto": "CAST(a AS VARCHAR)", - "redshift": "CAST(a AS TEXT)", + "redshift": "CAST(a AS VARCHAR(MAX))", "snowflake": "CAST(a AS TEXT)", "spark": "CAST(a AS STRING)", "starrocks": "CAST(a AS STRING)", @@ -344,6 +348,7 @@ class TestDialect(Validator): "duckdb": "CAST('2020-01-01' AS TIMESTAMP)", "hive": "CAST('2020-01-01' AS TIMESTAMP)", "presto": "CAST('2020-01-01' AS TIMESTAMP)", + "sqlite": "'2020-01-01'", }, ) self.validate_all( @@ -373,7 +378,7 @@ class TestDialect(Validator): "duckdb": "CAST(x AS TEXT)", "hive": "CAST(x AS STRING)", "presto": "CAST(x AS VARCHAR)", - "redshift": "CAST(x AS TEXT)", + "redshift": "CAST(x AS VARCHAR(MAX))", }, ) self.validate_all( @@ -488,7 +493,9 @@ class TestDialect(Validator): "mysql": "DATE_ADD(x, INTERVAL 1 DAY)", "postgres": "x + INTERVAL '1' 'day'", "presto": "DATE_ADD('day', 1, x)", + "snowflake": "DATEADD(x, 1, 'day')", "spark": "DATE_ADD(x, 1)", + "sqlite": "DATE(x, '1 day')", "starrocks": "DATE_ADD(x, INTERVAL 1 DAY)", "tsql": "DATEADD(day, 1, x)", }, @@ -594,6 +601,7 @@ class TestDialect(Validator): "hive": "TO_DATE(x)", "presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)", "spark": "TO_DATE(x)", + "sqlite": "x", }, ) self.validate_all( @@ -955,7 +963,7 @@ class TestDialect(Validator): }, ) self.validate_all( - "STR_POSITION('a', x)", + "STR_POSITION(x, 'a')", write={ "drill": "STRPOS(x, 'a')", "duckdb": "STRPOS(x, 'a')", @@ -971,7 +979,7 @@ class TestDialect(Validator): "POSITION('a', x, 3)", write={ "drill": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1", - "presto": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1", + "presto": "STRPOS(x, 'a', 3)", "spark": "LOCATE('a', x, 3)", "clickhouse": "position(x, 'a', 3)", "snowflake": "POSITION('a', x, 3)", @@ -982,9 +990,10 @@ class TestDialect(Validator): "CONCAT_WS('-', 'a', 'b')", write={ "duckdb": "CONCAT_WS('-', 'a', 'b')", - "presto": "ARRAY_JOIN(ARRAY['a', 'b'], '-')", + "presto": "CONCAT_WS('-', 'a', 'b')", "hive": "CONCAT_WS('-', 'a', 'b')", "spark": "CONCAT_WS('-', 'a', 'b')", + "trino": "CONCAT_WS('-', 'a', 'b')", }, ) @@ -992,9 +1001,10 @@ class TestDialect(Validator): "CONCAT_WS('-', x)", write={ "duckdb": "CONCAT_WS('-', x)", - "presto": "ARRAY_JOIN(x, '-')", "hive": "CONCAT_WS('-', x)", + "presto": "CONCAT_WS('-', x)", "spark": "CONCAT_WS('-', x)", + "trino": "CONCAT_WS('-', x)", }, ) self.validate_all( @@ -1118,6 +1128,7 @@ class TestDialect(Validator): self.validate_all( "SELECT x FROM y OFFSET 10 FETCH FIRST 3 ROWS ONLY", write={ + "sqlite": "SELECT x FROM y LIMIT 3 OFFSET 10", "oracle": "SELECT x FROM y OFFSET 10 ROWS FETCH FIRST 3 ROWS ONLY", }, ) @@ -1197,7 +1208,7 @@ class TestDialect(Validator): "oracle": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 CLOB, c2 CLOB(1024))", "postgres": "CREATE TABLE t (b1 BYTEA, b2 BYTEA(1024), c1 TEXT, c2 TEXT(1024))", "sqlite": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 TEXT, c2 TEXT(1024))", - "redshift": "CREATE TABLE t (b1 VARBYTE, b2 VARBYTE(1024), c1 TEXT, c2 TEXT(1024))", + "redshift": "CREATE TABLE t (b1 VARBYTE, b2 VARBYTE(1024), c1 VARCHAR(MAX), c2 VARCHAR(1024))", }, ) diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index bbf00b1..d485593 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -356,6 +356,30 @@ class TestHive(Validator): "spark": "SELECT a_b AS 1_a FROM test_table", }, ) + self.validate_all( + "SELECT 1a_1a FROM test_a", + write={ + "spark": "SELECT 1a_1a FROM test_a", + }, + ) + self.validate_all( + "SELECT 1a AS 1a_1a FROM test_a", + write={ + "spark": "SELECT 1a AS 1a_1a FROM test_a", + }, + ) + self.validate_all( + "CREATE TABLE test_table (1a STRING)", + write={ + "spark": "CREATE TABLE test_table (1a STRING)", + }, + ) + self.validate_all( + "CREATE TABLE test_table2 (1a_1a STRING)", + write={ + "spark": "CREATE TABLE test_table2 (1a_1a STRING)", + }, + ) self.validate_all( "PERCENTILE(x, 0.5)", write={ @@ -420,7 +444,7 @@ class TestHive(Validator): "LOCATE('a', x, 3)", write={ "duckdb": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1", - "presto": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1", + "presto": "STRPOS(x, 'a', 3)", "hive": "LOCATE('a', x, 3)", "spark": "LOCATE('a', x, 3)", }, diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 7cd686d..dfd2f8e 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -65,6 +65,17 @@ class TestMySQL(Validator): self.validate_identity("SET GLOBAL TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ WRITE") self.validate_identity("SELECT SCHEMA()") + def test_types(self): + self.validate_all( + "CAST(x AS MEDIUMTEXT) + CAST(y AS LONGTEXT)", + read={ + "mysql": "CAST(x AS MEDIUMTEXT) + CAST(y AS LONGTEXT)", + }, + write={ + "spark": "CAST(x AS TEXT) + CAST(y AS TEXT)", + }, + ) + def test_canonical_functions(self): self.validate_identity("SELECT LEFT('str', 2)", "SELECT SUBSTRING('str', 1, 2)") self.validate_identity("SELECT INSTR('str', 'substr')", "SELECT LOCATE('substr', 'str')") diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 583d349..2351e3b 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -46,14 +46,6 @@ class TestPostgres(Validator): " CONSTRAINT valid_discount CHECK (price > discounted_price))" }, ) - self.validate_all( - "CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY)", - write={"postgres": "CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY)"}, - ) - self.validate_all( - "CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY)", - write={"postgres": "CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY)"}, - ) with self.assertRaises(ParseError): transpile("CREATE TABLE products (price DECIMAL CHECK price > 0)", read="postgres") diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index ee535e9..195e382 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -152,6 +152,10 @@ class TestPresto(Validator): "spark": "FROM_UNIXTIME(x)", }, ) + self.validate_identity("FROM_UNIXTIME(a, b)") + self.validate_identity("FROM_UNIXTIME(a, b, c)") + self.validate_identity("TRIM(a, b)") + self.validate_identity("VAR_POP(a)") self.validate_all( "TO_UNIXTIME(x)", write={ @@ -302,6 +306,7 @@ class TestPresto(Validator): ) def test_presto(self): + self.validate_identity("SELECT BOOL_OR(a > 10) FROM asd AS T(a)") self.validate_all( 'SELECT a."b" FROM "foo"', write={ @@ -443,8 +448,10 @@ class TestPresto(Validator): "spark": UnsupportedError, }, ) + self.validate_identity("SELECT * FROM (VALUES (1))") self.validate_identity("START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE") self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ") + self.validate_identity("APPROX_PERCENTILE(a, b, c, d)") def test_encode_decode(self): self.validate_all( @@ -459,6 +466,12 @@ class TestPresto(Validator): "spark": "DECODE(x, 'utf-8')", }, ) + self.validate_all( + "FROM_UTF8(x, y)", + write={ + "presto": "FROM_UTF8(x, y)", + }, + ) self.validate_all( "ENCODE(x, 'utf-8')", write={ diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index f650c98..e20661e 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -89,7 +89,9 @@ class TestRedshift(Validator): self.validate_identity( "SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'" ) - self.validate_identity("CREATE TABLE SOUP DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL") + self.validate_identity( + "CREATE TABLE SOUP (LIKE other_table) DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL" + ) self.validate_identity( "CREATE TABLE sales (salesid INTEGER NOT NULL) DISTKEY(listid) COMPOUND SORTKEY(listid, sellerid) DISTSTYLE AUTO" ) @@ -102,3 +104,81 @@ class TestRedshift(Validator): self.validate_identity( "CREATE TABLE SOUP (SOUP1 VARCHAR(50) NOT NULL ENCODE ZSTD, SOUP2 VARCHAR(70) NULL ENCODE DELTA)" ) + + def test_values(self): + self.validate_all( + "SELECT a, b FROM (VALUES (1, 2)) AS t (a, b)", + write={ + "redshift": "SELECT a, b FROM (SELECT 1 AS a, 2 AS b) AS t", + }, + ) + self.validate_all( + "SELECT a, b FROM (VALUES (1, 2), (3, 4)) AS t (a, b)", + write={ + "redshift": "SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4) AS t", + }, + ) + self.validate_all( + "SELECT a, b FROM (VALUES (1, 2), (3, 4), (5, 6), (7, 8)) AS t (a, b)", + write={ + "redshift": "SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4 UNION ALL SELECT 5, 6 UNION ALL SELECT 7, 8) AS t", + }, + ) + self.validate_all( + "INSERT INTO t(a) VALUES (1), (2), (3)", + write={ + "redshift": "INSERT INTO t (a) VALUES (1), (2), (3)", + }, + ) + self.validate_all( + "INSERT INTO t(a, b) SELECT a, b FROM (VALUES (1, 2), (3, 4)) AS t (a, b)", + write={ + "redshift": "INSERT INTO t (a, b) SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4) AS t", + }, + ) + self.validate_all( + "INSERT INTO t(a, b) VALUES (1, 2), (3, 4)", + write={ + "redshift": "INSERT INTO t (a, b) VALUES (1, 2), (3, 4)", + }, + ) + + def test_create_table_like(self): + self.validate_all( + "CREATE TABLE t1 LIKE t2", + write={ + "redshift": "CREATE TABLE t1 (LIKE t2)", + }, + ) + self.validate_all( + "CREATE TABLE SOUP (LIKE other_table) DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL", + write={ + "redshift": "CREATE TABLE SOUP (LIKE other_table) DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL", + }, + ) + + def test_rename_table(self): + self.validate_all( + "ALTER TABLE db.t1 RENAME TO db.t2", + write={ + "spark": "ALTER TABLE db.t1 RENAME TO db.t2", + "redshift": "ALTER TABLE db.t1 RENAME TO t2", + }, + ) + + def test_varchar_max(self): + self.validate_all( + "CREATE TABLE TEST (cola VARCHAR(MAX))", + write={ + "redshift": 'CREATE TABLE "TEST" ("cola" VARCHAR(MAX))', + }, + identify=True, + ) + + def test_no_schema_binding(self): + self.validate_all( + "CREATE OR REPLACE VIEW v1 AS SELECT cola, colb FROM t1 WITH NO SCHEMA BINDING", + write={ + "redshift": "CREATE OR REPLACE VIEW v1 AS SELECT cola, colb FROM t1 WITH NO SCHEMA BINDING", + }, + ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index f287a89..fad858c 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -307,5 +307,12 @@ TBLPROPERTIES ( def test_iif(self): self.validate_all( - "SELECT IIF(cond, 'True', 'False')", write={"spark": "SELECT IF(cond, 'True', 'False')"} + "SELECT IIF(cond, 'True', 'False')", + write={"spark": "SELECT IF(cond, 'True', 'False')"}, + ) + + def test_bool_or(self): + self.validate_all( + "SELECT a, LOGICAL_OR(b) FROM table GROUP BY a", + write={"duckdb": "SELECT a, BOOL_OR(b) FROM table GROUP BY a"}, ) diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py new file mode 100644 index 0000000..e56de25 --- /dev/null +++ b/tests/dialects/test_teradata.py @@ -0,0 +1,23 @@ +from tests.dialects.test_dialect import Validator + + +class TestTeradata(Validator): + dialect = "teradata" + + def test_translate(self): + self.validate_all( + "TRANSLATE(x USING LATIN_TO_UNICODE)", + write={ + "teradata": "CAST(x AS CHAR CHARACTER SET UNICODE)", + }, + ) + self.validate_identity("CAST(x AS CHAR CHARACTER SET UNICODE)") + + def test_update(self): + self.validate_all( + "UPDATE A FROM schema.tableA AS A, (SELECT col1 FROM schema.tableA GROUP BY col1) AS B SET col2 = '' WHERE A.col1 = B.col1", + write={ + "teradata": "UPDATE A FROM schema.tableA AS A, (SELECT col1 FROM schema.tableA GROUP BY col1) AS B SET col2 = '' WHERE A.col1 = B.col1", + "mysql": "UPDATE A SET col2 = '' FROM schema.tableA AS A, (SELECT col1 FROM schema.tableA GROUP BY col1) AS B WHERE A.col1 = B.col1", + }, + ) diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index b74c05f..d2972ca 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -5,6 +5,13 @@ class TestTSQL(Validator): dialect = "tsql" def test_tsql(self): + self.validate_identity("DECLARE @TestVariable AS VARCHAR(100)='Save Our Planet'") + self.validate_identity("PRINT @TestVariable") + self.validate_identity("SELECT Employee_ID, Department_ID FROM @MyTableVar") + self.validate_identity("INSERT INTO @TestTable VALUES (1, 'Value1', 12, 20)") + self.validate_identity( + "SELECT x FROM @MyTableVar AS m JOIN Employee ON m.EmployeeID = Employee.EmployeeID" + ) self.validate_identity('SELECT "x"."y" FROM foo') self.validate_identity("SELECT * FROM #foo") self.validate_identity("SELECT * FROM ##foo") diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index beb5703..4e21d2b 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -59,6 +59,8 @@ map.x SELECT call.x a.b.INT(1.234) INT(x / 100) +time * 100 +int * 100 x IN (-1, 1) x IN ('a', 'a''a') x IN ((1)) @@ -69,6 +71,11 @@ x IS TRUE x IS FALSE x IS TRUE IS TRUE x LIKE y IS TRUE +MAP() +GREATEST(x) +LEAST(y) +MAX(a, b) +MIN(a, b) time zone ARRAY @@ -133,6 +140,7 @@ x AT TIME ZONE 'UTC' CAST('2025-11-20 00:00:00+00' AS TIMESTAMP) AT TIME ZONE 'Africa/Cairo' SET x = 1 SET -v +SET x = ';' COMMIT USE db NOT 1 @@ -170,6 +178,7 @@ SELECT COUNT(DISTINCT a, b) SELECT COUNT(DISTINCT a, b + 1) SELECT SUM(DISTINCT x) SELECT SUM(x IGNORE NULLS) AS x +SELECT TRUNCATE(a, b) SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 10) AS x SELECT ARRAY_AGG(STRUCT(x, x AS y) ORDER BY z DESC) AS x SELECT LAST_VALUE(x IGNORE NULLS) OVER y AS x @@ -622,7 +631,7 @@ SELECT 1 /* c1 */ + 2 /* c2 */ + 3 /* c3 */ SELECT 1 /* c1 */ + 2 /* c2 */, 3 /* c3 */ SELECT x FROM a.b.c /* x */, e.f.g /* x */ SELECT FOO(x /* c */) /* FOO */, b /* b */ -SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM VALUES (1 /* c4 */, "test" /* c5 */) /* c6 */ +SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM (VALUES (1 /* c4 */, "test" /* c5 */)) /* c6 */ SELECT a FROM x WHERE a COLLATE 'utf8_general_ci' = 'b' SELECT x AS INTO FROM bla SELECT * INTO newevent FROM event @@ -643,3 +652,21 @@ ALTER TABLE integers ALTER COLUMN i DROP DEFAULT ALTER TABLE mydataset.mytable DROP COLUMN A, DROP COLUMN IF EXISTS B ALTER TABLE mydataset.mytable ADD COLUMN A TEXT, ADD COLUMN IF NOT EXISTS B INT SELECT div.a FROM test_table AS div +WITH view AS (SELECT 1 AS x) SELECT * FROM view +CREATE TABLE asd AS SELECT asd FROM asd WITH NO DATA +CREATE TABLE asd AS SELECT asd FROM asd WITH DATA +ARRAY>> +ARRAY[1, 2, 3] +ARRAY[] +STRUCT +STRUCT("bla") +STRUCT("bla") +STRUCT(5) +STRUCT("2011-05-05") +STRUCT(1, t.str_col) +SELECT CAST(NULL AS ARRAY) IS NULL AS array_is_null +CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY) +CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY) +CREATE TABLE IF NOT EXISTS customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (INCREMENT BY 1)) +CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10 INCREMENT BY 1)) +CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10)) diff --git a/tests/fixtures/pretty.sql b/tests/fixtures/pretty.sql index 067fe77..64806eb 100644 --- a/tests/fixtures/pretty.sql +++ b/tests/fixtures/pretty.sql @@ -322,3 +322,23 @@ SELECT * /* multi line comment */; +WITH table_data AS ( + SELECT 'bob' AS name, ARRAY['banana', 'apple', 'orange'] AS fruit_basket +) +SELECT + name, + fruit, + basket_index +FROM table_data +CROSS JOIN UNNEST(fruit_basket) AS fruit WITH OFFSET basket_index; +WITH table_data AS ( + SELECT + 'bob' AS name, + ARRAY('banana', 'apple', 'orange') AS fruit_basket +) +SELECT + name, + fruit, + basket_index +FROM table_data +CROSS JOIN UNNEST(fruit_basket) AS fruit WITH OFFSET AS basket_index; diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 906e08c..9e5f988 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -624,6 +624,10 @@ FROM foo""", self.assertEqual(catalog_db_and_table.args.get("catalog"), exp.to_identifier("catalog")) with self.assertRaises(ValueError): exp.to_table(1) + empty_string = exp.to_table("") + self.assertEqual(empty_string.name, "") + self.assertIsNone(table_only.args.get("db")) + self.assertIsNone(table_only.args.get("catalog")) def test_to_column(self): column_only = exp.to_column("column_name") @@ -715,3 +719,9 @@ FROM foo""", self.assertEqual(exp.DataType.build("OBJECT").sql(), "OBJECT") self.assertEqual(exp.DataType.build("NULL").sql(), "NULL") self.assertEqual(exp.DataType.build("UNKNOWN").sql(), "UNKNOWN") + + def test_rename_table(self): + self.assertEqual( + exp.rename_table("t1", "t2").sql(), + "ALTER TABLE t1 RENAME TO t2", + ) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 887f427..af21679 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -6,7 +6,7 @@ from pandas.testing import assert_frame_equal import sqlglot from sqlglot import exp, optimizer, parse_one -from sqlglot.errors import OptimizeError +from sqlglot.errors import OptimizeError, SchemaError from sqlglot.optimizer.annotate_types import annotate_types from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope from sqlglot.schema import MappingSchema @@ -161,7 +161,7 @@ class TestOptimizer(unittest.TestCase): def test_qualify_columns__invalid(self): for sql in load_sql_fixtures("optimizer/qualify_columns__invalid.sql"): with self.subTest(sql): - with self.assertRaises(OptimizeError): + with self.assertRaises((OptimizeError, SchemaError)): optimizer.qualify_columns.qualify_columns(parse_one(sql), schema=self.schema) def test_lower_identities(self): diff --git a/tests/test_parser.py b/tests/test_parser.py index 03b801b..dbde437 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -325,3 +325,9 @@ class TestParser(unittest.TestCase): "Expected table name", logger, ) + + def test_rename_table(self): + self.assertEqual( + parse_one("ALTER TABLE foo RENAME TO bar").sql(), + "ALTER TABLE foo RENAME TO bar", + ) diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 3a7fea4..3e094f5 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -272,6 +272,11 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""", "WITH A(filter) AS (VALUES 1, 2, 3) SELECT * FROM A WHERE filter >= 2", "WITH A(filter) AS (VALUES (1), (2), (3)) SELECT * FROM A WHERE filter >= 2", ) + self.validate( + "SELECT BOOL_OR(a > 10) FROM (VALUES 1, 2, 15) AS T(a)", + "SELECT BOOL_OR(a > 10) FROM (VALUES (1), (2), (15)) AS T(a)", + write="presto", + ) def test_alter(self): self.validate( @@ -447,6 +452,9 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""", self.assertEqual(generated, pretty) self.assertEqual(parse_one(sql), parse_one(pretty)) + def test_pretty_line_breaks(self): + self.assertEqual(transpile("SELECT '1\n2'", pretty=True)[0], "SELECT\n '1\n2'") + @mock.patch("sqlglot.parser.logger") def test_error_level(self, logger): invalid = "x + 1. (" -- cgit v1.2.3