diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-10-04 12:14:45 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-10-04 12:14:45 +0000 |
commit | a34653eb21369376f0e054dd989311afcb167f5b (patch) | |
tree | 5a0280adce195af0be654f79fd99395fd2932c19 /sqlglot | |
parent | Releasing debian version 18.7.0-1. (diff) | |
download | sqlglot-a34653eb21369376f0e054dd989311afcb167f5b.tar.xz sqlglot-a34653eb21369376f0e054dd989311afcb167f5b.zip |
Merging upstream version 18.11.2.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r-- | sqlglot/dialects/bigquery.py | 30 | ||||
-rw-r--r-- | sqlglot/dialects/clickhouse.py | 30 | ||||
-rw-r--r-- | sqlglot/dialects/databricks.py | 20 | ||||
-rw-r--r-- | sqlglot/dialects/dialect.py | 9 | ||||
-rw-r--r-- | sqlglot/dialects/hive.py | 33 | ||||
-rw-r--r-- | sqlglot/dialects/mysql.py | 107 | ||||
-rw-r--r-- | sqlglot/dialects/oracle.py | 8 | ||||
-rw-r--r-- | sqlglot/dialects/postgres.py | 14 | ||||
-rw-r--r-- | sqlglot/dialects/presto.py | 3 | ||||
-rw-r--r-- | sqlglot/dialects/redshift.py | 2 | ||||
-rw-r--r-- | sqlglot/dialects/snowflake.py | 47 | ||||
-rw-r--r-- | sqlglot/dialects/spark.py | 11 | ||||
-rw-r--r-- | sqlglot/dialects/spark2.py | 26 | ||||
-rw-r--r-- | sqlglot/dialects/tsql.py | 9 | ||||
-rw-r--r-- | sqlglot/executor/env.py | 1 | ||||
-rw-r--r-- | sqlglot/expressions.py | 81 | ||||
-rw-r--r-- | sqlglot/generator.py | 67 | ||||
-rw-r--r-- | sqlglot/optimizer/annotate_types.py | 115 | ||||
-rw-r--r-- | sqlglot/optimizer/canonicalize.py | 19 | ||||
-rw-r--r-- | sqlglot/optimizer/merge_subqueries.py | 2 | ||||
-rw-r--r-- | sqlglot/optimizer/normalize_identifiers.py | 18 | ||||
-rw-r--r-- | sqlglot/optimizer/simplify.py | 97 | ||||
-rw-r--r-- | sqlglot/parser.py | 87 | ||||
-rw-r--r-- | sqlglot/tokens.py | 18 |
24 files changed, 701 insertions, 153 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 1349c56..0d741b5 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -190,6 +190,16 @@ class BigQuery(Dialect): "%D": "%m/%d/%y", } + ESCAPE_SEQUENCES = { + "\\a": "\a", + "\\b": "\b", + "\\f": "\f", + "\\n": "\n", + "\\r": "\r", + "\\t": "\t", + "\\v": "\v", + } + FORMAT_MAPPING = { "DD": "%d", "MM": "%m", @@ -212,15 +222,14 @@ class BigQuery(Dialect): @classmethod def normalize_identifier(cls, expression: E) -> E: - # In BigQuery, CTEs aren't case-sensitive, but table names are (by default, at least). - # The following check is essentially a heuristic to detect tables based on whether or - # not they're qualified. if isinstance(expression, exp.Identifier): parent = expression.parent - while isinstance(parent, exp.Dot): parent = parent.parent + # In BigQuery, CTEs aren't case-sensitive, but table names are (by default, at least). + # The following check is essentially a heuristic to detect tables based on whether or + # not they're qualified. It also avoids normalizing UDFs, because they're case-sensitive. if ( not isinstance(parent, exp.UserDefinedFunction) and not (isinstance(parent, exp.Table) and parent.db) @@ -419,6 +428,7 @@ class BigQuery(Dialect): RENAME_TABLE_WITH_DB = False NVL2_SUPPORTED = False UNNEST_WITH_ORDINALITY = False + COLLATE_IS_FUNC = True TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -520,18 +530,6 @@ class BigQuery(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } - UNESCAPED_SEQUENCE_TABLE = str.maketrans( # type: ignore - { - "\a": "\\a", - "\b": "\\b", - "\f": "\\f", - "\n": "\\n", - "\r": "\\r", - "\t": "\\t", - "\v": "\\v", - } - ) - # from: https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#reserved_keywords RESERVED_KEYWORDS = { *generator.Generator.RESERVED_KEYWORDS, diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 7446081..e9d9326 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -2,7 +2,7 @@ from __future__ import annotations import typing as t -from sqlglot import exp, generator, parser, tokens +from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, inline_array_sql, @@ -21,18 +21,33 @@ def _lower_func(sql: str) -> str: return sql[:index].lower() + sql[index:] +def _quantile_sql(self, e): + quantile = e.args["quantile"] + args = f"({self.sql(e, 'this')})" + if isinstance(quantile, exp.Array): + func = self.func("quantiles", *quantile) + else: + func = self.func("quantile", quantile) + return func + args + + class ClickHouse(Dialect): NORMALIZE_FUNCTIONS: bool | str = False NULL_ORDERING = "nulls_are_last" STRICT_STRING_CONCAT = True SUPPORTS_USER_DEFINED_TYPES = False + ESCAPE_SEQUENCES = { + "\\0": "\0", + } + class Tokenizer(tokens.Tokenizer): COMMENTS = ["--", "#", "#!", ("/*", "*/")] IDENTIFIERS = ['"', "`"] STRING_ESCAPES = ["'", "\\"] BIT_STRINGS = [("0b", "")] HEX_STRINGS = [("0x", ""), ("0X", "")] + HEREDOC_STRINGS = ["$"] KEYWORDS = { **tokens.Tokenizer.KEYWORDS, @@ -55,6 +70,7 @@ class ClickHouse(Dialect): "LOWCARDINALITY": TokenType.LOWCARDINALITY, "MAP": TokenType.MAP, "NESTED": TokenType.NESTED, + "SAMPLE": TokenType.TABLE_SAMPLE, "TUPLE": TokenType.STRUCT, "UINT128": TokenType.UINT128, "UINT16": TokenType.USMALLINT, @@ -64,6 +80,11 @@ class ClickHouse(Dialect): "UINT8": TokenType.UTINYINT, } + SINGLE_TOKENS = { + **tokens.Tokenizer.SINGLE_TOKENS, + "$": TokenType.HEREDOC_STRING, + } + class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, @@ -301,6 +322,7 @@ class ClickHouse(Dialect): QUERY_HINTS = False STRUCT_DELIMITER = ("(", ")") NVL2_SUPPORTED = False + TABLESAMPLE_REQUIRES_PARENS = False STRING_TYPE_MAPPING = { exp.DataType.Type.CHAR: "String", @@ -348,6 +370,7 @@ class ClickHouse(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, + exp.Select: transforms.preprocess([transforms.eliminate_qualify]), exp.AnyValue: rename_func("any"), exp.ApproxDistinct: rename_func("uniq"), exp.Array: inline_array_sql, @@ -359,12 +382,13 @@ class ClickHouse(Dialect): "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this ), exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL", + exp.IsNan: rename_func("isNaN"), exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)), exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.Pivot: no_pivot_sql, - exp.Quantile: lambda self, e: self.func("quantile", e.args.get("quantile")) - + f"({self.sql(e, 'this')})", + exp.Quantile: _quantile_sql, exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})", + exp.StartsWith: rename_func("startsWith"), exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})", exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)), exp.Xor: lambda self, e: self.func("xor", e.this, e.expression, *e.expressions), diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 39daad7..a044bc0 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -51,6 +51,26 @@ class Databricks(Spark): exp.ToChar: lambda self, e: self.function_fallback_sql(e), } + def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str: + constraint = expression.find(exp.GeneratedAsIdentityColumnConstraint) + kind = expression.args.get("kind") + if ( + constraint + and isinstance(kind, exp.DataType) + and kind.this in exp.DataType.INTEGER_TYPES + ): + # only BIGINT generated identity constraints are supported + expression = expression.copy() + expression.set("kind", exp.DataType.build("bigint")) + return super().columndef_sql(expression, sep) + + def generatedasidentitycolumnconstraint_sql( + self, expression: exp.GeneratedAsIdentityColumnConstraint + ) -> str: + expression = expression.copy() + expression.set("this", True) # trigger ALWAYS in super class + return super().generatedasidentitycolumnconstraint_sql(expression) + class Tokenizer(Spark.Tokenizer): HEX_STRINGS = [] diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index ccf04da..bd839af 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -81,6 +81,8 @@ class _Dialect(type): klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()} klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING) + klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()} + klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer) klass.parser_class = getattr(klass, "Parser", Parser) klass.generator_class = getattr(klass, "Generator", Generator) @@ -188,6 +190,9 @@ class Dialect(metaclass=_Dialect): # special syntax cast(x as date format 'yyyy') defaults to time_mapping FORMAT_MAPPING: t.Dict[str, str] = {} + # Mapping of an unescaped escape sequence to the corresponding character + ESCAPE_SEQUENCES: t.Dict[str, str] = {} + # Columns that are auto-generated by the engine corresponding to this dialect # Such columns may be excluded from SELECT * queries, for example PSEUDOCOLUMNS: t.Set[str] = set() @@ -204,6 +209,8 @@ class Dialect(metaclass=_Dialect): INVERSE_TIME_MAPPING: t.Dict[str, str] = {} INVERSE_TIME_TRIE: t.Dict = {} + INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {} + def __eq__(self, other: t.Any) -> bool: return type(self) == other @@ -245,7 +252,7 @@ class Dialect(metaclass=_Dialect): """ Normalizes an unquoted identifier to either lower or upper case, thus essentially making it case-insensitive. If a dialect treats all identifiers as case-insensitive, - they will be normalized regardless of being quoted or not. + they will be normalized to lowercase regardless of being quoted or not. """ if isinstance(expression, exp.Identifier) and ( not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index a427870..3f925a7 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -51,6 +51,32 @@ TIME_DIFF_FACTOR = { DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH") +def _create_sql(self, expression: exp.Create) -> str: + expression = expression.copy() + + # remove UNIQUE column constraints + for constraint in expression.find_all(exp.UniqueColumnConstraint): + if constraint.parent: + constraint.parent.pop() + + properties = expression.args.get("properties") + temporary = any( + isinstance(prop, exp.TemporaryProperty) + for prop in (properties.expressions if properties else []) + ) + + # CTAS with temp tables map to CREATE TEMPORARY VIEW + kind = expression.args["kind"] + if kind.upper() == "TABLE" and temporary: + if expression.expression: + return f"CREATE TEMPORARY VIEW {self.sql(expression, 'this')} AS {self.sql(expression, 'expression')}" + else: + # CREATE TEMPORARY TABLE may require storage provider + expression = self.temporary_storage_provider(expression) + + return create_with_partitions_sql(self, expression) + + def _add_date_sql(self: Hive.Generator, expression: exp.DateAdd | exp.DateSub) -> str: unit = expression.text("unit").upper() func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1)) @@ -429,7 +455,7 @@ class Hive(Dialect): if e.args.get("allow_null") else "NOT NULL", exp.VarMap: var_map_sql, - exp.Create: create_with_partitions_sql, + exp.Create: _create_sql, exp.Quantile: rename_func("PERCENTILE"), exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"), exp.RegexpExtract: regexp_extract_sql, @@ -478,8 +504,13 @@ class Hive(Dialect): exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA, exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, + exp.WithDataProperty: exp.Properties.Location.UNSUPPORTED, } + def temporary_storage_provider(self, expression: exp.Create) -> exp.Create: + # Hive has no temporary storage provider (there are hive settings though) + return expression + def parameter_sql(self, expression: exp.Parameter) -> str: this = self.sql(expression, "this") parent = expression.parent diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 554241d..59a0a2a 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -66,7 +66,9 @@ def _str_to_date(args: t.List) -> exp.StrToDate: return exp.StrToDate(this=seq_get(args, 0), format=date_format) -def _str_to_date_sql(self: MySQL.Generator, expression: exp.StrToDate | exp.StrToTime) -> str: +def _str_to_date_sql( + self: MySQL.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate +) -> str: date_format = self.format_time(expression) return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})" @@ -86,8 +88,10 @@ def _trim_sql(self: MySQL.Generator, expression: exp.Trim) -> str: return f"TRIM({trim_type}{remove_chars}{from_part}{target})" -def _date_add_sql(kind: str) -> t.Callable[[MySQL.Generator, exp.DateAdd | exp.DateSub], str]: - def func(self: MySQL.Generator, expression: exp.DateAdd | exp.DateSub) -> str: +def _date_add_sql( + kind: str, +) -> t.Callable[[MySQL.Generator, exp.Expression], str]: + def func(self: MySQL.Generator, expression: exp.Expression) -> str: this = self.sql(expression, "this") unit = expression.text("unit").upper() or "DAY" return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})" @@ -95,6 +99,30 @@ def _date_add_sql(kind: str) -> t.Callable[[MySQL.Generator, exp.DateAdd | exp.D return func +def _ts_or_ds_to_date_sql(self: MySQL.Generator, expression: exp.TsOrDsToDate) -> str: + time_format = expression.args.get("format") + if time_format: + return _str_to_date_sql(self, expression) + return f"DATE({self.sql(expression, 'this')})" + + +def _remove_ts_or_ds_to_date( + to_sql: t.Optional[t.Callable[[MySQL.Generator, exp.Expression], str]] = None, + args: t.Tuple[str, ...] = ("this",), +) -> t.Callable[[MySQL.Generator, exp.Func], str]: + def func(self: MySQL.Generator, expression: exp.Func) -> str: + expression = expression.copy() + + for arg_key in args: + arg = expression.args.get(arg_key) + if isinstance(arg, exp.TsOrDsToDate) and not arg.args.get("format"): + expression.set(arg_key, arg.this) + + return to_sql(self, expression) if to_sql else self.function_fallback_sql(expression) + + return func + + class MySQL(Dialect): # https://dev.mysql.com/doc/refman/8.0/en/identifiers.html IDENTIFIERS_CAN_START_WITH_DIGIT = True @@ -233,6 +261,7 @@ class MySQL(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, + "DATE": lambda args: exp.TsOrDsToDate(this=seq_get(args, 0)), "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd), "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"), "DATE_SUB": parse_date_delta_with_interval(exp.DateSub), @@ -240,14 +269,33 @@ class MySQL(Dialect): "ISNULL": isnull_to_is_null, "LOCATE": locate_to_strposition, "MONTHNAME": lambda args: exp.TimeToStr( - this=seq_get(args, 0), + this=exp.TsOrDsToDate(this=seq_get(args, 0)), format=exp.Literal.string("%B"), ), "STR_TO_DATE": _str_to_date, + "TO_DAYS": lambda args: exp.paren( + exp.DateDiff( + this=exp.TsOrDsToDate(this=seq_get(args, 0)), + expression=exp.TsOrDsToDate(this=exp.Literal.string("0000-01-01")), + unit=exp.var("DAY"), + ) + + 1 + ), + "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))), + "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))), + "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))), + "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))), + "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate(this=seq_get(args, 0))), + "WEEK": lambda args: exp.Week( + this=exp.TsOrDsToDate(this=seq_get(args, 0)), mode=seq_get(args, 1) + ), + "WEEKOFYEAR": lambda args: exp.WeekOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))), + "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate(this=seq_get(args, 0))), } FUNCTION_PARSERS = { **parser.Parser.FUNCTION_PARSERS, + "CHAR": lambda self: self._parse_chr(), "GROUP_CONCAT": lambda self: self.expression( exp.GroupConcat, this=self._parse_lambda(), @@ -531,6 +579,18 @@ class MySQL(Dialect): return super()._parse_type(parse_interval=parse_interval) + def _parse_chr(self) -> t.Optional[exp.Expression]: + expressions = self._parse_csv(self._parse_conjunction) + kwargs: t.Dict[str, t.Any] = {"this": seq_get(expressions, 0)} + + if len(expressions) > 1: + kwargs["expressions"] = expressions[1:] + + if self._match(TokenType.USING): + kwargs["charset"] = self._parse_var() + + return self.expression(exp.Chr, **kwargs) + class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True NULL_ORDERING_SUPPORTED = False @@ -544,25 +604,33 @@ class MySQL(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, exp.CurrentDate: no_paren_current_date_sql, - exp.DateDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression), - exp.DateAdd: _date_add_sql("ADD"), + exp.DateDiff: _remove_ts_or_ds_to_date( + lambda self, e: self.func("DATEDIFF", e.this, e.expression), ("this", "expression") + ), + exp.DateAdd: _remove_ts_or_ds_to_date(_date_add_sql("ADD")), exp.DateStrToDate: datestrtodate_sql, - exp.DateSub: _date_add_sql("SUB"), + exp.DateSub: _remove_ts_or_ds_to_date(_date_add_sql("SUB")), exp.DateTrunc: _date_trunc_sql, - exp.DayOfMonth: rename_func("DAYOFMONTH"), - exp.DayOfWeek: rename_func("DAYOFWEEK"), - exp.DayOfYear: rename_func("DAYOFYEAR"), + exp.Day: _remove_ts_or_ds_to_date(), + exp.DayOfMonth: _remove_ts_or_ds_to_date(rename_func("DAYOFMONTH")), + exp.DayOfWeek: _remove_ts_or_ds_to_date(rename_func("DAYOFWEEK")), + exp.DayOfYear: _remove_ts_or_ds_to_date(rename_func("DAYOFYEAR")), exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""", exp.ILike: no_ilike_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, exp.JSONKeyValue: json_keyvalue_comma_sql, exp.Max: max_or_greatest, exp.Min: min_or_least, + exp.Month: _remove_ts_or_ds_to_date(), exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"), exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")), exp.Pivot: no_pivot_sql, exp.Select: transforms.preprocess( - [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins] + [ + transforms.eliminate_distinct_on, + transforms.eliminate_semi_and_anti_joins, + transforms.eliminate_qualify, + ] ), exp.StrPosition: strposition_to_locate_sql, exp.StrToDate: _str_to_date_sql, @@ -573,10 +641,16 @@ class MySQL(Dialect): exp.TimestampSub: date_add_interval_sql("DATE", "SUB"), exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)), - exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)), + exp.TimeToStr: _remove_ts_or_ds_to_date( + lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)) + ), exp.Trim: _trim_sql, exp.TryCast: no_trycast_sql, - exp.WeekOfYear: rename_func("WEEKOFYEAR"), + exp.TsOrDsAdd: _date_add_sql("ADD"), + exp.TsOrDsToDate: _ts_or_ds_to_date_sql, + exp.Week: _remove_ts_or_ds_to_date(), + exp.WeekOfYear: _remove_ts_or_ds_to_date(rename_func("WEEKOFYEAR")), + exp.Year: _remove_ts_or_ds_to_date(), } UNSIGNED_TYPE_MAPPING = { @@ -585,6 +659,7 @@ class MySQL(Dialect): exp.DataType.Type.UMEDIUMINT: "MEDIUMINT", exp.DataType.Type.USMALLINT: "SMALLINT", exp.DataType.Type.UTINYINT: "TINYINT", + exp.DataType.Type.UDECIMAL: "DECIMAL", } TIMESTAMP_TYPE_MAPPING = { @@ -717,3 +792,9 @@ class MySQL(Dialect): limit_offset = f"{offset}, {limit}" if offset else limit return f" LIMIT {limit_offset}" return "" + + def chr_sql(self, expression: exp.Chr) -> str: + this = self.expressions(sqls=[expression.this] + expression.expressions) + charset = expression.args.get("charset") + using = f" USING {self.sql(charset)}" if charset else "" + return f"CHAR({this}{using})" diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 0a4926d..6a007ab 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -153,6 +153,7 @@ class Oracle(Dialect): JOIN_HINTS = False TABLE_HINTS = False COLUMN_JOIN_MARKS_SUPPORTED = True + DATA_TYPE_SPECIFIERS_ALLOWED = True LIMIT_FETCH = "FETCH" @@ -179,7 +180,12 @@ class Oracle(Dialect): ), exp.Group: transforms.preprocess([transforms.unalias_group]), exp.ILike: no_ilike_sql, - exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), + exp.Select: transforms.preprocess( + [ + transforms.eliminate_distinct_on, + transforms.eliminate_qualify, + ] + ), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "), exp.Substring: rename_func("SUBSTR"), diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 342fd95..008727c 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -22,6 +22,7 @@ from sqlglot.dialects.dialect import ( rename_func, simplify_literal, str_position_sql, + struct_extract_sql, timestamptrunc_sql, timestrtotime_sql, trim_sql, @@ -248,11 +249,10 @@ class Postgres(Dialect): } class Tokenizer(tokens.Tokenizer): - QUOTES = ["'", "$$"] - BIT_STRINGS = [("b'", "'"), ("B'", "'")] HEX_STRINGS = [("x'", "'"), ("X'", "'")] BYTE_STRINGS = [("e'", "'"), ("E'", "'")] + HEREDOC_STRINGS = ["$"] KEYWORDS = { **tokens.Tokenizer.KEYWORDS, @@ -296,7 +296,7 @@ class Postgres(Dialect): SINGLE_TOKENS = { **tokens.Tokenizer.SINGLE_TOKENS, - "$": TokenType.PARAMETER, + "$": TokenType.HEREDOC_STRING, } VAR_SINGLE_TOKENS = {"$"} @@ -420,9 +420,15 @@ class Postgres(Dialect): exp.Pow: lambda self, e: self.binary(e, "^"), exp.RegexpLike: lambda self, e: self.binary(e, "~"), exp.RegexpILike: lambda self, e: self.binary(e, "~*"), - exp.Select: transforms.preprocess([transforms.eliminate_semi_and_anti_joins]), + exp.Select: transforms.preprocess( + [ + transforms.eliminate_semi_and_anti_joins, + transforms.eliminate_qualify, + ] + ), exp.StrPosition: str_position_sql, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.StructExtract: struct_extract_sql, exp.Substring: _substring_sql, exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToTime: timestrtotime_sql, diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 0d8d4ab..e5cfa1c 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -309,6 +309,9 @@ class Presto(Dialect): exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'", exp.First: _first_last_sql, exp.Group: transforms.preprocess([transforms.unalias_group]), + exp.GroupConcat: lambda self, e: self.func( + "ARRAY_JOIN", self.func("ARRAY_AGG", e.this), e.args.get("separator") + ), exp.Hex: rename_func("TO_HEX"), exp.If: if_sql(), exp.ILike: no_ilike_sql, diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 2145844..88e4448 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -83,7 +83,7 @@ class Redshift(Postgres): class Tokenizer(Postgres.Tokenizer): BIT_STRINGS = [] HEX_STRINGS = [] - STRING_ESCAPES = ["\\"] + STRING_ESCAPES = ["\\", "'"] KEYWORDS = { **Postgres.Tokenizer.KEYWORDS, diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 5c49331..fc3e0fa 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -239,6 +239,8 @@ class Snowflake(Dialect): class Parser(parser.Parser): IDENTIFY_PIVOT_STRINGS = True + TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS | {TokenType.WINDOW} + FUNCTIONS = { **parser.Parser.FUNCTIONS, "ARRAYAGG": exp.ArrayAgg.from_arg_list, @@ -318,6 +320,43 @@ class Snowflake(Dialect): "TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"), } + STAGED_FILE_SINGLE_TOKENS = { + TokenType.DOT, + TokenType.MOD, + TokenType.SLASH, + } + + def _parse_table_parts(self, schema: bool = False) -> exp.Table: + # https://docs.snowflake.com/en/user-guide/querying-stage + table: t.Optional[exp.Expression] = None + if self._match_text_seq("@"): + table_name = "@" + while True: + self._advance() + table_name += self._prev.text + if not self._match_set(self.STAGED_FILE_SINGLE_TOKENS, advance=False): + break + while self._match_set(self.STAGED_FILE_SINGLE_TOKENS): + table_name += self._prev.text + + table = exp.var(table_name) + elif self._match(TokenType.STRING, advance=False): + table = self._parse_string() + + if table: + file_format = None + pattern = None + + if self._match_text_seq("(", "FILE_FORMAT", "=>"): + file_format = self._parse_string() or super()._parse_table_parts() + if self._match_text_seq(",", "PATTERN", "=>"): + pattern = self._parse_string() + self._match_r_paren() + + return self.expression(exp.Table, this=table, format=file_format, pattern=pattern) + + return super()._parse_table_parts(schema=schema) + def _parse_id_var( self, any_token: bool = True, @@ -394,6 +433,8 @@ class Snowflake(Dialect): TABLE_HINTS = False QUERY_HINTS = False AGGREGATE_FILTER_SUPPORTED = False + SUPPORTS_TABLE_COPY = False + COLLATE_IS_FUNC = True TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -423,6 +464,12 @@ class Snowflake(Dialect): exp.Max: max_or_greatest, exp.Min: min_or_least, exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", + exp.PercentileCont: transforms.preprocess( + [transforms.add_within_group_for_percentiles] + ), + exp.PercentileDisc: transforms.preprocess( + [transforms.add_within_group_for_percentiles] + ), exp.RegexpILike: _regexpilike_sql, exp.Select: transforms.preprocess( [ diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 9d4a1ab..2eaa2ae 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -54,6 +54,14 @@ class Spark(Spark2): FUNCTION_PARSERS = Spark2.Parser.FUNCTION_PARSERS.copy() FUNCTION_PARSERS.pop("ANY_VALUE") + def _parse_generated_as_identity( + self, + ) -> exp.GeneratedAsIdentityColumnConstraint | exp.ComputedColumnConstraint: + this = super()._parse_generated_as_identity() + if this.expression: + return self.expression(exp.ComputedColumnConstraint, this=this.expression) + return this + class Generator(Spark2.Generator): TYPE_MAPPING = { **Spark2.Generator.TYPE_MAPPING, @@ -73,6 +81,9 @@ class Spark(Spark2): TRANSFORMS.pop(exp.DateDiff) TRANSFORMS.pop(exp.Group) + def computedcolumnconstraint_sql(self, expression: exp.ComputedColumnConstraint) -> str: + return f"GENERATED ALWAYS AS ({self.sql(expression, 'this')})" + def anyvalue_sql(self, expression: exp.AnyValue) -> str: return self.function_fallback_sql(expression) diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index 3dc9838..4130375 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -5,7 +5,6 @@ import typing as t from sqlglot import exp, transforms from sqlglot.dialects.dialect import ( binary_from_function, - create_with_partitions_sql, format_time_lambda, is_parse_json, move_insert_cte_sql, @@ -17,22 +16,6 @@ from sqlglot.dialects.hive import Hive from sqlglot.helper import seq_get -def _create_sql(self: Spark2.Generator, e: exp.Create) -> str: - kind = e.args["kind"] - properties = e.args.get("properties") - - if ( - kind.upper() == "TABLE" - and e.expression - and any( - isinstance(prop, exp.TemporaryProperty) - for prop in (properties.expressions if properties else []) - ) - ): - return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}" - return create_with_partitions_sql(self, e) - - def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str: keys = expression.args.get("keys") values = expression.args.get("values") @@ -118,6 +101,8 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression: class Spark2(Hive): class Parser(Hive.Parser): + TRIM_PATTERN_FIRST = True + FUNCTIONS = { **Hive.Parser.FUNCTIONS, "AGGREGATE": exp.Reduce.from_arg_list, @@ -192,7 +177,6 @@ class Spark2(Hive): exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})", exp.BitwiseLeftShift: rename_func("SHIFTLEFT"), exp.BitwiseRightShift: rename_func("SHIFTRIGHT"), - exp.Create: _create_sql, exp.DateFromParts: rename_func("MAKE_DATE"), exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")), exp.DayOfMonth: rename_func("DAYOFMONTH"), @@ -236,6 +220,12 @@ class Spark2(Hive): WRAP_DERIVED_VALUES = False CREATE_FUNCTION_RETURN_AS = False + def temporary_storage_provider(self, expression: exp.Create) -> exp.Create: + # spark2, spark, Databricks require a storage provider for temporary tables + provider = exp.FileFormatProperty(this=exp.Literal.string("parquet")) + expression.args["properties"].append("expressions", provider) + return expression + def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: if is_parse_json(expression.this): schema = f"'{self.sql(expression, 'to')}'" diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index fa62e78..6aa49e4 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -14,6 +14,7 @@ from sqlglot.dialects.dialect import ( parse_date_delta, rename_func, timestrtotime_sql, + ts_or_ds_to_date_sql, ) from sqlglot.expressions import DataType from sqlglot.helper import seq_get @@ -590,6 +591,7 @@ class TSQL(Dialect): NVL2_SUPPORTED = False ALTER_TABLE_ADD_COLUMN_KEYWORD = False LIMIT_FETCH = "FETCH" + COMPUTED_COLUMN_WITH_TYPE = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -619,7 +621,11 @@ class TSQL(Dialect): exp.Min: min_or_least, exp.NumberToStr: _format_sql, exp.Select: transforms.preprocess( - [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins] + [ + transforms.eliminate_distinct_on, + transforms.eliminate_semi_and_anti_joins, + transforms.eliminate_qualify, + ] ), exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this), exp.SHA2: lambda self, e: self.func( @@ -630,6 +636,7 @@ class TSQL(Dialect): exp.TemporaryProperty: lambda self, e: "", exp.TimeStrToTime: timestrtotime_sql, exp.TimeToStr: _format_sql, + exp.TsOrDsToDate: ts_or_ds_to_date_sql("tsql"), } TRANSFORMS.pop(exp.ReturnsProperty) diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py index 9f63100..bf2941c 100644 --- a/sqlglot/executor/env.py +++ b/sqlglot/executor/env.py @@ -202,4 +202,5 @@ ENV = { "CURRENTTIME": datetime.datetime.now, "CURRENTDATE": datetime.date.today, "STRFTIME": null_if_any(lambda fmt, arg: datetime.datetime.fromisoformat(arg).strftime(fmt)), + "TRIM": null_if_any(lambda this, e=None: this.strip(e)), } diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 8e9575e..1e4aad6 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -52,6 +52,9 @@ class _Expression(type): return klass +SQLGLOT_META = "sqlglot.meta" + + class Expression(metaclass=_Expression): """ The base class for all expressions in a syntax tree. Each Expression encapsulates any necessary @@ -266,7 +269,14 @@ class Expression(metaclass=_Expression): if self.comments is None: self.comments = [] if comments: - self.comments.extend(comments) + for comment in comments: + _, *meta = comment.split(SQLGLOT_META) + if meta: + for kv in "".join(meta).split(","): + k, *v = kv.split("=") + value = v[0].strip() if v else True + self.meta[k.strip()] = value + self.comments.append(comment) def append(self, arg_key: str, value: t.Any) -> None: """ @@ -1036,11 +1046,14 @@ class Create(DDL): "indexes": False, "no_schema_binding": False, "begin": False, + "end": False, "clone": False, } # https://docs.snowflake.com/en/sql-reference/sql/create-clone +# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_clone_statement +# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_copy class Clone(Expression): arg_types = { "this": True, @@ -1048,6 +1061,7 @@ class Clone(Expression): "kind": False, "shallow": False, "expression": False, + "copy": False, } @@ -1610,6 +1624,11 @@ class Identifier(Expression): return self.name +# https://www.postgresql.org/docs/current/indexes-opclass.html +class Opclass(Expression): + arg_types = {"this": True, "expression": True} + + class Index(Expression): arg_types = { "this": False, @@ -2156,6 +2175,10 @@ class QueryTransform(Expression): } +class SampleProperty(Property): + arg_types = {"this": True} + + class SchemaCommentProperty(Property): arg_types = {"this": True} @@ -2440,6 +2463,8 @@ class Table(Expression): "hints": False, "system_time": False, "version": False, + "format": False, + "pattern": False, } @property @@ -2465,17 +2490,17 @@ class Table(Expression): return [] @property - def parts(self) -> t.List[Identifier]: + def parts(self) -> t.List[Expression]: """Return the parts of a table in order catalog, db, table.""" - parts: t.List[Identifier] = [] + parts: t.List[Expression] = [] for arg in ("catalog", "db", "this"): part = self.args.get(arg) - if isinstance(part, Identifier): - parts.append(part) - elif isinstance(part, Dot): + if isinstance(part, Dot): parts.extend(part.flatten()) + elif isinstance(part, Expression): + parts.append(part) return parts @@ -2910,6 +2935,7 @@ class Select(Subqueryable): prefix="OFFSET", dialect=dialect, copy=copy, + into_arg="expression", **opts, ) @@ -3572,6 +3598,7 @@ class DataType(Expression): UINT128 = auto() UINT256 = auto() UMEDIUMINT = auto() + UDECIMAL = auto() UNIQUEIDENTIFIER = auto() UNKNOWN = auto() # Sentinel value, useful for type annotation USERDEFINED = "USER-DEFINED" @@ -3693,13 +3720,13 @@ class DataType(Expression): # https://www.postgresql.org/docs/15/datatype-pseudo.html -class PseudoType(Expression): - pass +class PseudoType(DataType): + arg_types = {"this": True} # https://www.postgresql.org/docs/15/datatype-oid.html -class ObjectIdentifier(Expression): - pass +class ObjectIdentifier(DataType): + arg_types = {"this": True} # WHERE x <OP> EXISTS|ALL|ANY|SOME(SELECT ...) @@ -4027,10 +4054,20 @@ class TimeUnit(Expression): return self.args.get("unit") +class IntervalOp(TimeUnit): + arg_types = {"unit": True, "expression": True} + + def interval(self): + return Interval( + this=self.expression.copy(), + unit=self.unit.copy(), + ) + + # https://www.oracletutorial.com/oracle-basics/oracle-interval/ # https://trino.io/docs/current/language/types.html#interval-day-to-second # https://docs.databricks.com/en/sql/language-manual/data-types/interval-type.html -class IntervalSpan(Expression): +class IntervalSpan(DataType): arg_types = {"this": True, "expression": True} @@ -4269,7 +4306,7 @@ class CastToStrType(Func): arg_types = {"this": True, "to": True} -class Collate(Binary): +class Collate(Binary, Func): pass @@ -4284,6 +4321,12 @@ class Coalesce(Func): _sql_names = ["COALESCE", "IFNULL", "NVL"] +class Chr(Func): + arg_types = {"this": True, "charset": False, "expressions": False} + is_var_len_args = True + _sql_names = ["CHR", "CHAR"] + + class Concat(Func): arg_types = {"expressions": True} is_var_len_args = True @@ -4326,11 +4369,11 @@ class CurrentUser(Func): arg_types = {"this": False} -class DateAdd(Func, TimeUnit): +class DateAdd(Func, IntervalOp): arg_types = {"this": True, "expression": True, "unit": False} -class DateSub(Func, TimeUnit): +class DateSub(Func, IntervalOp): arg_types = {"this": True, "expression": True, "unit": False} @@ -4347,11 +4390,11 @@ class DateTrunc(Func): return self.args["unit"] -class DatetimeAdd(Func, TimeUnit): +class DatetimeAdd(Func, IntervalOp): arg_types = {"this": True, "expression": True, "unit": False} -class DatetimeSub(Func, TimeUnit): +class DatetimeSub(Func, IntervalOp): arg_types = {"this": True, "expression": True, "unit": False} @@ -4375,6 +4418,10 @@ class DayOfYear(Func): _sql_names = ["DAY_OF_YEAR", "DAYOFYEAR"] +class ToDays(Func): + pass + + class WeekOfYear(Func): _sql_names = ["WEEK_OF_YEAR", "WEEKOFYEAR"] @@ -6160,7 +6207,7 @@ def table_name(table: Table | str, dialect: DialectType = None) -> str: The table name. """ - table = maybe_parse(table, into=Table) + table = maybe_parse(table, into=Table, dialect=dialect) if not table: raise ValueError(f"Cannot parse {table}") diff --git a/sqlglot/generator.py b/sqlglot/generator.py index b1ee783..edc6939 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -86,6 +86,7 @@ class Generator: exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}", exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}", exp.ReturnsProperty: lambda self, e: self.naked_property(e), + exp.SampleProperty: lambda self, e: f"SAMPLE BY {self.sql(e, 'this')}", exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET", exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}", exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}", @@ -204,6 +205,21 @@ class Generator: # Whether or not session variables / parameters are supported, e.g. @x in T-SQL SUPPORTS_PARAMETERS = True + # Whether or not to include the type of a computed column in the CREATE DDL + COMPUTED_COLUMN_WITH_TYPE = True + + # Whether or not CREATE TABLE .. COPY .. is supported. False means we'll generate CLONE instead of COPY + SUPPORTS_TABLE_COPY = True + + # Whether or not parentheses are required around the table sample's expression + TABLESAMPLE_REQUIRES_PARENS = True + + # Whether or not COLLATE is a function instead of a binary operator + COLLATE_IS_FUNC = False + + # Whether or not data types support additional specifiers like e.g. CHAR or BYTE (oracle) + DATA_TYPE_SPECIFIERS_ALLOWED = False + TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", exp.DataType.Type.NVARCHAR: "VARCHAR", @@ -282,6 +298,7 @@ class Generator: exp.RowFormatProperty: exp.Properties.Location.POST_SCHEMA, exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA, exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA, + exp.SampleProperty: exp.Properties.Location.POST_SCHEMA, exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA, exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA, exp.Set: exp.Properties.Location.POST_SCHEMA, @@ -324,13 +341,12 @@ class Generator: exp.Paren, ) - UNESCAPED_SEQUENCE_TABLE = None # type: ignore - SENTINEL_LINE_BREAK = "__SQLGLOT__LB__" # Autofilled INVERSE_TIME_MAPPING: t.Dict[str, str] = {} INVERSE_TIME_TRIE: t.Dict = {} + INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {} INDEX_OFFSET = 0 UNNEST_COLUMN_ONLY = False ALIAS_POST_TABLESAMPLE = False @@ -480,8 +496,7 @@ class Generator: if not comments or isinstance(expression, exp.Binary): return sql - sep = "\n" if self.pretty else " " - comments_sql = sep.join( + comments_sql = " ".join( f"/*{self.pad_comment(comment)}*/" for comment in comments if comment ) @@ -649,6 +664,9 @@ class Generator: position = self.sql(expression, "position") position = f" {position}" if position else "" + if expression.find(exp.ComputedColumnConstraint) and not self.COMPUTED_COLUMN_WITH_TYPE: + kind = "" + return f"{exists}{column}{kind}{constraints}{position}" def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str: @@ -750,9 +768,11 @@ class Generator: ) begin = " BEGIN" if expression.args.get("begin") else "" + end = " END" if expression.args.get("end") else "" + expression_sql = self.sql(expression, "expression") if expression_sql: - expression_sql = f"{begin}{self.sep()}{expression_sql}" + expression_sql = f"{begin}{self.sep()}{expression_sql}{end}" if self.CREATE_FUNCTION_RETURN_AS or not isinstance(expression.expression, exp.Return): if properties_locs.get(exp.Properties.Location.POST_ALIAS): @@ -817,7 +837,8 @@ class Generator: def clone_sql(self, expression: exp.Clone) -> str: this = self.sql(expression, "this") shallow = "SHALLOW " if expression.args.get("shallow") else "" - this = f"{shallow}CLONE {this}" + keyword = "COPY" if expression.args.get("copy") and self.SUPPORTS_TABLE_COPY else "CLONE" + this = f"{shallow}{keyword} {this}" when = self.sql(expression, "when") if when: @@ -877,7 +898,7 @@ class Generator: def datatypeparam_sql(self, expression: exp.DataTypeParam) -> str: this = self.sql(expression, "this") specifier = self.sql(expression, "expression") - specifier = f" {specifier}" if specifier else "" + specifier = f" {specifier}" if specifier and self.DATA_TYPE_SPECIFIERS_ALLOWED else "" return f"{this}{specifier}" def datatype_sql(self, expression: exp.DataType) -> str: @@ -1329,8 +1350,13 @@ class Generator: pivots = f" {pivots}" if pivots else "" joins = self.expressions(expression, key="joins", sep="", skip_first=True) laterals = self.expressions(expression, key="laterals", sep="") + file_format = self.sql(expression, "format") + if file_format: + pattern = self.sql(expression, "pattern") + pattern = f", PATTERN => {pattern}" if pattern else "" + file_format = f" (FILE_FORMAT => {file_format}{pattern})" - return f"{table}{version}{alias}{hints}{pivots}{joins}{laterals}" + return f"{table}{version}{file_format}{alias}{hints}{pivots}{joins}{laterals}" def tablesample_sql( self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS " @@ -1343,6 +1369,7 @@ class Generator: else: this = self.sql(expression, "this") alias = "" + method = self.sql(expression, "method") method = f"{method.upper()} " if method and self.TABLESAMPLE_WITH_METHOD else "" numerator = self.sql(expression, "bucket_numerator") @@ -1354,13 +1381,20 @@ class Generator: percent = f"{percent} PERCENT" if percent else "" rows = self.sql(expression, "rows") rows = f"{rows} ROWS" if rows else "" + size = self.sql(expression, "size") if size and self.TABLESAMPLE_SIZE_IS_PERCENT: size = f"{size} PERCENT" + seed = self.sql(expression, "seed") seed = f" {seed_prefix} ({seed})" if seed else "" kind = expression.args.get("kind", "TABLESAMPLE") - return f"{this} {kind} {method}({bucket}{percent}{rows}{size}){seed}{alias}" + + expr = f"{bucket}{percent}{rows}{size}" + if self.TABLESAMPLE_REQUIRES_PARENS: + expr = f"({expr})" + + return f"{this} {kind} {method}{expr}{seed}{alias}" def pivot_sql(self, expression: exp.Pivot) -> str: expressions = self.expressions(expression, flat=True) @@ -1638,8 +1672,8 @@ class Generator: def escape_str(self, text: str) -> str: text = text.replace(self.QUOTE_END, self._escaped_quote_end) - if self.UNESCAPED_SEQUENCE_TABLE: - text = text.translate(self.UNESCAPED_SEQUENCE_TABLE) + if self.INVERSE_ESCAPE_SEQUENCES: + text = "".join(self.INVERSE_ESCAPE_SEQUENCES.get(ch, ch) for ch in text) elif self.pretty: text = text.replace("\n", self.SENTINEL_LINE_BREAK) return text @@ -2301,6 +2335,8 @@ class Generator: return f"CURRENT_DATE({zone})" if zone else "CURRENT_DATE" def collate_sql(self, expression: exp.Collate) -> str: + if self.COLLATE_IS_FUNC: + return self.function_fallback_sql(expression) return self.binary(expression, "COLLATE") def command_sql(self, expression: exp.Command) -> str: @@ -2359,7 +2395,7 @@ class Generator: collate = f" COLLATE {collate}" if collate else "" using = self.sql(expression, "using") using = f" USING {using}" if using else "" - return f"ALTER COLUMN {this} TYPE {dtype}{collate}{using}" + return f"ALTER COLUMN {this} SET DATA TYPE {dtype}{collate}{using}" default = self.sql(expression, "default") if default: @@ -2396,7 +2432,7 @@ class Generator: elif isinstance(actions[0], exp.Delete): actions = self.expressions(expression, key="actions", flat=True) else: - actions = self.expressions(expression, key="actions") + actions = self.expressions(expression, key="actions", flat=True) exists = " IF EXISTS" if expression.args.get("exists") else "" only = " ONLY" if expression.args.get("only") else "" @@ -2593,7 +2629,7 @@ class Generator: self, expression: t.Optional[exp.Expression] = None, key: t.Optional[str] = None, - sqls: t.Optional[t.List[str]] = None, + sqls: t.Optional[t.Collection[str | exp.Expression]] = None, flat: bool = False, indent: bool = True, skip_first: bool = False, @@ -2841,6 +2877,9 @@ class Generator: def columnprefix_sql(self, expression: exp.ColumnPrefix) -> str: return f"{self.sql(expression, 'this')}({self.sql(expression, 'expression')})" + def opclass_sql(self, expression: exp.Opclass) -> str: + return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" + def cached_generator( cache: t.Optional[t.Dict[int, str]] = None diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index afc6995..17af6ac 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -1,5 +1,7 @@ from __future__ import annotations +import datetime +import functools import typing as t from sqlglot import exp @@ -11,6 +13,16 @@ from sqlglot.schema import Schema, ensure_schema if t.TYPE_CHECKING: B = t.TypeVar("B", bound=exp.Binary) + BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type] + BinaryCoercions = t.Dict[ + t.Tuple[exp.DataType.Type, exp.DataType.Type], + BinaryCoercionFunc, + ] + + +# Interval units that operate on date components +DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"} + def annotate_types( expression: E, @@ -48,6 +60,59 @@ def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[Type return lambda self, e: self._annotate_with_type(e, data_type) +def _is_iso_date(text: str) -> bool: + try: + datetime.date.fromisoformat(text) + return True + except ValueError: + return False + + +def _is_iso_datetime(text: str) -> bool: + try: + datetime.datetime.fromisoformat(text) + return True + except ValueError: + return False + + +def _coerce_literal_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type: + date_text = l.name + unit = r.text("unit").lower() + + is_iso_date = _is_iso_date(date_text) + + if is_iso_date and unit in DATE_UNITS: + l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATE)) + return exp.DataType.Type.DATE + + # An ISO date is also an ISO datetime, but not vice versa + if is_iso_date or _is_iso_datetime(date_text): + l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATETIME)) + return exp.DataType.Type.DATETIME + + return exp.DataType.Type.UNKNOWN + + +def _coerce_date_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type: + unit = r.text("unit").lower() + if unit not in DATE_UNITS: + return exp.DataType.Type.DATETIME + return l.type.this if l.type else exp.DataType.Type.UNKNOWN + + +def swap_args(func: BinaryCoercionFunc) -> BinaryCoercionFunc: + @functools.wraps(func) + def _swapped(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type: + return func(r, l) + + return _swapped + + +def swap_all(coercions: BinaryCoercions) -> BinaryCoercions: + return {**coercions, **{(b, a): swap_args(func) for (a, b), func in coercions.items()}} + + class _TypeAnnotator(type): def __new__(cls, clsname, bases, attrs): klass = super().__new__(cls, clsname, bases, attrs) @@ -104,10 +169,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.DataType.Type.DATE: { exp.CurrentDate, exp.Date, - exp.DateAdd, exp.DateFromParts, exp.DateStrToDate, - exp.DateSub, exp.DateTrunc, exp.DiToDate, exp.StrToDate, @@ -212,6 +275,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), + exp.DateAdd: lambda self, e: self._annotate_dateadd(e), + exp.DateSub: lambda self, e: self._annotate_dateadd(e), exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), @@ -234,21 +299,41 @@ class TypeAnnotator(metaclass=_TypeAnnotator): # Specifies what types a given type can be coerced into (autofilled) COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {} + # Coercion functions for binary operations. + # Map of type pairs to a callable that takes both sides of the binary operation and returns the resulting type. + BINARY_COERCIONS: BinaryCoercions = { + **swap_all( + { + (t, exp.DataType.Type.INTERVAL): _coerce_literal_and_interval + for t in exp.DataType.TEXT_TYPES + } + ), + **swap_all( + { + (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): _coerce_date_and_interval, + } + ), + } + def __init__( self, schema: Schema, annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None, coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None, + binary_coercions: t.Optional[BinaryCoercions] = None, ) -> None: self.schema = schema self.annotators = annotators or self.ANNOTATORS self.coerces_to = coerces_to or self.COERCES_TO + self.binary_coercions = binary_coercions or self.BINARY_COERCIONS # Caches the ids of annotated sub-Expressions, to ensure we only visit them once self._visited: t.Set[int] = set() - def _set_type(self, expression: exp.Expression, target_type: exp.DataType) -> None: - expression.type = target_type + def _set_type( + self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type + ) -> None: + expression.type = target_type # type: ignore self._visited.add(id(expression)) def annotate(self, expression: E) -> E: @@ -342,8 +427,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator): def _annotate_binary(self, expression: B) -> B: self._annotate_args(expression) - left_type = expression.left.type.this - right_type = expression.right.type.this + left, right = expression.left, expression.right + left_type, right_type = left.type.this, right.type.this if isinstance(expression, exp.Connector): if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL: @@ -357,6 +442,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator): self._set_type(expression, exp.DataType.Type.BOOLEAN) elif isinstance(expression, exp.Predicate): self._set_type(expression, exp.DataType.Type.BOOLEAN) + elif (left_type, right_type) in self.binary_coercions: + self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right)) else: self._set_type(expression, self._maybe_coerce(left_type, right_type)) @@ -421,3 +508,19 @@ class TypeAnnotator(metaclass=_TypeAnnotator): ) return expression + + def _annotate_dateadd(self, expression: exp.IntervalOp) -> exp.IntervalOp: + self._annotate_args(expression) + + if expression.this.type.this in exp.DataType.TEXT_TYPES: + datatype = _coerce_literal_and_interval(expression.this, expression.interval()) + elif ( + expression.this.type.is_type(exp.DataType.Type.DATE) + and expression.text("unit").lower() not in DATE_UNITS + ): + datatype = exp.DataType.Type.DATETIME + else: + datatype = expression.this.type + + self._set_type(expression, datatype) + return expression diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py index e45d1e3..ec3b3af 100644 --- a/sqlglot/optimizer/canonicalize.py +++ b/sqlglot/optimizer/canonicalize.py @@ -45,9 +45,11 @@ def coerce_type(node: exp.Expression) -> exp.Expression: _coerce_date(node.left, node.right) elif isinstance(node, exp.Between): _coerce_date(node.this, node.args["low"]) - elif isinstance(node, exp.Extract): - if node.expression.type.this not in exp.DataType.TEMPORAL_TYPES: - _replace_cast(node.expression, "datetime") + elif isinstance(node, exp.Extract) and not node.expression.type.is_type( + *exp.DataType.TEMPORAL_TYPES + ): + _replace_cast(node.expression, exp.DataType.Type.DATETIME) + return node @@ -67,7 +69,7 @@ def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression: _replace_int_predicate(expression.left) _replace_int_predicate(expression.right) - elif isinstance(expression, (exp.Where, exp.Having)): + elif isinstance(expression, (exp.Where, exp.Having, exp.If)): _replace_int_predicate(expression.this) return expression @@ -89,13 +91,16 @@ def _coerce_date(a: exp.Expression, b: exp.Expression) -> None: and b.type and b.type.this not in (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL) ): - _replace_cast(b, "date") + _replace_cast(b, exp.DataType.Type.DATE) -def _replace_cast(node: exp.Expression, to: str) -> None: +def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None: node.replace(exp.cast(node.copy(), to=to)) def _replace_int_predicate(expression: exp.Expression) -> None: - if expression.type and expression.type.this in exp.DataType.INTEGER_TYPES: + if isinstance(expression, exp.Coalesce): + for _, child in expression.iter_expressions(): + _replace_int_predicate(child) + elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES: expression.replace(exp.NEQ(this=expression.copy(), expression=exp.Literal.number(0))) diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index 976c9ad..b0b2b3d 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -181,7 +181,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS) and inner_select.args.get("from") and not outer_scope.pivots - and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions) + and not any(e.find(exp.AggFunc, exp.Select, exp.Explode) for e in inner_select.expressions) and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1) and not ( isinstance(from_or_join, exp.Join) diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py index 54cf02b..32f3a92 100644 --- a/sqlglot/optimizer/normalize_identifiers.py +++ b/sqlglot/optimizer/normalize_identifiers.py @@ -22,6 +22,13 @@ def normalize_identifiers(expression, dialect=None): Normalize all unquoted identifiers to either lower or upper case, depending on the dialect. This essentially makes those identifiers case-insensitive. + It's possible to make this a no-op by adding a special comment next to the + identifier of interest: + + SELECT a /* sqlglot.meta case_sensitive */ FROM table + + In this example, the identifier `a` will not be normalized. + Note: Some dialects (e.g. BigQuery) treat identifiers as case-insensitive even when they're quoted, so in these cases all identifiers are normalized. @@ -43,4 +50,13 @@ def normalize_identifiers(expression, dialect=None): """ if isinstance(expression, str): expression = exp.to_identifier(expression) - return expression.transform(Dialect.get_or_raise(dialect).normalize_identifier, copy=False) + + dialect = Dialect.get_or_raise(dialect) + + def _normalize(node: E) -> E: + if not node.meta.get("case_sensitive"): + exp.replace_children(node, _normalize) + node = dialect.normalize_identifier(node) + return node + + return _normalize(expression) diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index d08c692..51214c4 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -387,10 +387,6 @@ def _is_number(expression: exp.Expression) -> bool: return expression.is_number -def _is_date(expression: exp.Expression) -> bool: - return isinstance(expression, exp.Cast) and extract_date(expression) is not None - - def _is_interval(expression: exp.Expression) -> bool: return isinstance(expression, exp.Interval) and extract_interval(expression) is not None @@ -422,18 +418,15 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression: if r.is_number: a_predicate = _is_number b_predicate = _is_number - elif _is_date(r): - a_predicate = _is_date + elif _is_date_literal(r): + a_predicate = _is_date_literal b_predicate = _is_interval else: return expression if l.__class__ in INVERSE_DATE_OPS: a = l.this - b = exp.Interval( - this=l.expression.copy(), - unit=l.unit.copy(), - ) + b = l.interval() else: a, b = l.left, l.right @@ -509,14 +502,14 @@ def _simplify_binary(expression, a, b): if boolean: return boolean - elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval): + elif _is_date_literal(a) and isinstance(b, exp.Interval): a, b = extract_date(a), extract_interval(b) if a and b: if isinstance(expression, exp.Add): return date_literal(a + b) if isinstance(expression, exp.Sub): return date_literal(a - b) - elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast): + elif isinstance(a, exp.Interval) and _is_date_literal(b): a, b = extract_interval(a), extract_date(b) # you cannot subtract a date from an interval if a and b and isinstance(expression, exp.Add): @@ -702,11 +695,7 @@ DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS} def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool: - return ( - isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) - and isinstance(right, exp.Cast) - and right.is_type(*exp.DataType.TEMPORAL_TYPES) - ) + return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right) @catch(ModuleNotFoundError, UnsupportedUnit) @@ -731,15 +720,26 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression: unit = l.unit.name.lower() date = extract_date(r) + if not date: + return expression + return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression elif isinstance(expression, exp.In): l = expression.this rs = expression.expressions - if all(_is_datetrunc_predicate(l, r) for r in rs): + if rs and all(_is_datetrunc_predicate(l, r) for r in rs): unit = l.unit.name.lower() - ranges = [r for r in [_datetrunc_range(extract_date(r), unit) for r in rs] if r] + ranges = [] + for r in rs: + date = extract_date(r) + if not date: + return expression + drange = _datetrunc_range(date, unit) + if drange: + ranges.append(drange) + if not ranges: return expression @@ -811,18 +811,59 @@ def eval_boolean(expression, a, b): return None -def extract_date(cast): - # The "fromisoformat" conversion could fail if the cast is used on an identifier, - # so in that case we can't extract the date. +def cast_as_date(value: t.Any) -> t.Optional[datetime.date]: + if isinstance(value, datetime.datetime): + return value.date() + if isinstance(value, datetime.date): + return value try: - if cast.args["to"].this == exp.DataType.Type.DATE: - return datetime.date.fromisoformat(cast.name) - if cast.args["to"].this == exp.DataType.Type.DATETIME: - return datetime.datetime.fromisoformat(cast.name) + return datetime.datetime.fromisoformat(value).date() except ValueError: return None +def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]: + if isinstance(value, datetime.datetime): + return value + if isinstance(value, datetime.date): + return datetime.datetime(year=value.year, month=value.month, day=value.day) + try: + return datetime.datetime.fromisoformat(value) + except ValueError: + return None + + +def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]: + if not value: + return None + if to.is_type(exp.DataType.Type.DATE): + return cast_as_date(value) + if to.is_type(*exp.DataType.TEMPORAL_TYPES): + return cast_as_datetime(value) + return None + + +def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]: + if isinstance(cast, exp.Cast): + to = cast.to + elif isinstance(cast, exp.TsOrDsToDate): + to = exp.DataType.build(exp.DataType.Type.DATE) + else: + return None + + if isinstance(cast.this, exp.Literal): + value: t.Any = cast.this.name + elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)): + value = extract_date(cast.this) + else: + return None + return cast_value(value, to) + + +def _is_date_literal(expression: exp.Expression) -> bool: + return extract_date(expression) is not None + + def extract_interval(expression): n = int(expression.name) unit = expression.text("unit").lower() @@ -836,7 +877,9 @@ def extract_interval(expression): def date_literal(date): return exp.cast( exp.Literal.string(date), - "DATETIME" if isinstance(date, datetime.datetime) else "DATE", + exp.DataType.Type.DATETIME + if isinstance(date, datetime.datetime) + else exp.DataType.Type.DATE, ) diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 84b2639..5e56961 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -178,6 +178,7 @@ class Parser(metaclass=_Parser): TokenType.DATERANGE, TokenType.DATEMULTIRANGE, TokenType.DECIMAL, + TokenType.UDECIMAL, TokenType.BIGDECIMAL, TokenType.UUID, TokenType.GEOGRAPHY, @@ -215,6 +216,7 @@ class Parser(metaclass=_Parser): TokenType.MEDIUMINT: TokenType.UMEDIUMINT, TokenType.SMALLINT: TokenType.USMALLINT, TokenType.TINYINT: TokenType.UTINYINT, + TokenType.DECIMAL: TokenType.UDECIMAL, } SUBQUERY_PREDICATES = { @@ -338,6 +340,7 @@ class Parser(metaclass=_Parser): TRIM_TYPES = {"LEADING", "TRAILING", "BOTH"} FUNC_TOKENS = { + TokenType.COLLATE, TokenType.COMMAND, TokenType.CURRENT_DATE, TokenType.CURRENT_DATETIME, @@ -590,6 +593,9 @@ class Parser(metaclass=_Parser): exp.National, this=token.text ), TokenType.RAW_STRING: lambda self, token: self.expression(exp.RawString, this=token.text), + TokenType.HEREDOC_STRING: lambda self, token: self.expression( + exp.RawString, this=token.text + ), TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(), } @@ -666,6 +672,9 @@ class Parser(metaclass=_Parser): "RETURNS": lambda self: self._parse_returns(), "ROW": lambda self: self._parse_row(), "ROW_FORMAT": lambda self: self._parse_property_assignment(exp.RowFormatProperty), + "SAMPLE": lambda self: self.expression( + exp.SampleProperty, this=self._match_text_seq("BY") and self._parse_bitwise() + ), "SET": lambda self: self.expression(exp.SetProperty, multi=False), "SETTINGS": lambda self: self.expression( exp.SettingsProperty, expressions=self._parse_csv(self._parse_set_item) @@ -847,8 +856,11 @@ class Parser(metaclass=_Parser): INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"} + CLONE_KEYWORDS = {"CLONE", "COPY"} CLONE_KINDS = {"TIMESTAMP", "OFFSET", "STATEMENT"} + OPCLASS_FOLLOW_KEYWORDS = {"ASC", "DESC", "NULLS"} + TABLE_INDEX_HINT_TOKENS = {TokenType.FORCE, TokenType.IGNORE, TokenType.USE} WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS} @@ -863,6 +875,8 @@ class Parser(metaclass=_Parser): NULL_TOKENS = {TokenType.NULL} + UNNEST_OFFSET_ALIAS_TOKENS = ID_VAR_TOKENS - SET_OPERATIONS + STRICT_CAST = True # A NULL arg in CONCAT yields NULL by default @@ -880,9 +894,12 @@ class Parser(metaclass=_Parser): # Whether or not the table sample clause expects CSV syntax TABLESAMPLE_CSV = False - # Whether or not the SET command needs a delimiter (e.g. "=") for assignments. + # Whether or not the SET command needs a delimiter (e.g. "=") for assignments SET_REQUIRES_ASSIGNMENT_DELIMITER = True + # Whether the TRIM function expects the characters to trim as its first argument + TRIM_PATTERN_FIRST = False + __slots__ = ( "error_level", "error_message_context", @@ -1268,6 +1285,7 @@ class Parser(metaclass=_Parser): indexes = None no_schema_binding = None begin = None + end = None clone = None def extend_props(temp_props: t.Optional[exp.Properties]) -> None: @@ -1299,6 +1317,8 @@ class Parser(metaclass=_Parser): else: expression = self._parse_statement() + end = self._match_text_seq("END") + if return_: expression = self.expression(exp.Return, this=expression) elif create_token.token_type == TokenType.INDEX: @@ -1344,7 +1364,8 @@ class Parser(metaclass=_Parser): shallow = self._match_text_seq("SHALLOW") - if self._match_text_seq("CLONE"): + if self._match_texts(self.CLONE_KEYWORDS): + copy = self._prev.text.lower() == "copy" clone = self._parse_table(schema=True) when = self._match_texts({"AT", "BEFORE"}) and self._prev.text.upper() clone_kind = ( @@ -1361,6 +1382,7 @@ class Parser(metaclass=_Parser): kind=clone_kind, shallow=shallow, expression=clone_expression, + copy=copy, ) return self.expression( @@ -1376,6 +1398,7 @@ class Parser(metaclass=_Parser): indexes=indexes, no_schema_binding=no_schema_binding, begin=begin, + end=end, clone=clone, ) @@ -2445,21 +2468,32 @@ class Parser(metaclass=_Parser): kwargs["using"] = self._parse_wrapped_id_vars() elif not (kind and kind.token_type == TokenType.CROSS): index = self._index - joins = self._parse_joins() + join = self._parse_join() - if joins and self._match(TokenType.ON): + if join and self._match(TokenType.ON): kwargs["on"] = self._parse_conjunction() - elif joins and self._match(TokenType.USING): + elif join and self._match(TokenType.USING): kwargs["using"] = self._parse_wrapped_id_vars() else: - joins = None + join = None self._retreat(index) - kwargs["this"].set("joins", joins) + kwargs["this"].set("joins", [join] if join else None) comments = [c for token in (method, side, kind) if token for c in token.comments] return self.expression(exp.Join, comments=comments, **kwargs) + def _parse_opclass(self) -> t.Optional[exp.Expression]: + this = self._parse_conjunction() + if self._match_texts(self.OPCLASS_FOLLOW_KEYWORDS, advance=False): + return this + + opclass = self._parse_var(any_token=True) + if opclass: + return self.expression(exp.Opclass, this=this, expression=opclass) + + return this + def _parse_index( self, index: t.Optional[exp.Expression] = None, @@ -2486,7 +2520,7 @@ class Parser(metaclass=_Parser): using = self._parse_var(any_token=True) if self._match(TokenType.USING) else None if self._match(TokenType.L_PAREN, advance=False): - columns = self._parse_wrapped_csv(self._parse_ordered) + columns = self._parse_wrapped_csv(lambda: self._parse_ordered(self._parse_opclass)) else: columns = None @@ -2677,7 +2711,9 @@ class Parser(metaclass=_Parser): if not offset and self._match_pair(TokenType.WITH, TokenType.OFFSET): self._match(TokenType.ALIAS) - offset = self._parse_id_var() or exp.to_identifier("offset") + offset = self._parse_id_var( + any_token=False, tokens=self.UNNEST_OFFSET_ALIAS_TOKENS + ) or exp.to_identifier("offset") return self.expression(exp.Unnest, expressions=expressions, alias=alias, offset=offset) @@ -2715,14 +2751,18 @@ class Parser(metaclass=_Parser): ) method = self._parse_var(tokens=(TokenType.ROW,)) - self._match(TokenType.L_PAREN) + matched_l_paren = self._match(TokenType.L_PAREN) if self.TABLESAMPLE_CSV: num = None expressions = self._parse_csv(self._parse_primary) else: expressions = None - num = self._parse_primary() + num = ( + self._parse_factor() + if self._match(TokenType.NUMBER, advance=False) + else self._parse_primary() + ) if self._match_text_seq("BUCKET"): bucket_numerator = self._parse_number() @@ -2737,7 +2777,8 @@ class Parser(metaclass=_Parser): elif num: size = num - self._match(TokenType.R_PAREN) + if matched_l_paren: + self._match_r_paren() if self._match(TokenType.L_PAREN): method = self._parse_var() @@ -2965,8 +3006,8 @@ class Parser(metaclass=_Parser): return None return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered)) - def _parse_ordered(self) -> exp.Ordered: - this = self._parse_conjunction() + def _parse_ordered(self, parse_method: t.Optional[t.Callable] = None) -> exp.Ordered: + this = parse_method() if parse_method else self._parse_conjunction() asc = self._match(TokenType.ASC) desc = self._match(TokenType.DESC) or (asc and False) @@ -3144,7 +3185,7 @@ class Parser(metaclass=_Parser): if self._match_text_seq("DISTINCT", "FROM"): klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ - return self.expression(klass, this=this, expression=self._parse_expression()) + return self.expression(klass, this=this, expression=self._parse_conjunction()) expression = self._parse_null() or self._parse_boolean() if not expression: @@ -3760,7 +3801,9 @@ class Parser(metaclass=_Parser): return self.expression(exp.CompressColumnConstraint, this=self._parse_bitwise()) - def _parse_generated_as_identity(self) -> exp.GeneratedAsIdentityColumnConstraint: + def _parse_generated_as_identity( + self, + ) -> exp.GeneratedAsIdentityColumnConstraint | exp.ComputedColumnConstraint: if self._match_text_seq("BY", "DEFAULT"): on_null = self._match_pair(TokenType.ON, TokenType.NULL) this = self.expression( @@ -4382,16 +4425,18 @@ class Parser(metaclass=_Parser): position = None collation = None + expression = None if self._match_texts(self.TRIM_TYPES): position = self._prev.text.upper() - expression = self._parse_bitwise() + this = self._parse_bitwise() if self._match_set((TokenType.FROM, TokenType.COMMA)): - this = self._parse_bitwise() - else: - this = expression - expression = None + invert_order = self._prev.token_type == TokenType.FROM or self.TRIM_PATTERN_FIRST + expression = self._parse_bitwise() + + if invert_order: + this, expression = expression, this if self._match(TokenType.COLLATE): collation = self._parse_bitwise() diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 4d5f198..080a86b 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -77,6 +77,7 @@ class TokenType(AutoName): BYTE_STRING = auto() NATIONAL_STRING = auto() RAW_STRING = auto() + HEREDOC_STRING = auto() # types BIT = auto() @@ -98,6 +99,7 @@ class TokenType(AutoName): FLOAT = auto() DOUBLE = auto() DECIMAL = auto() + UDECIMAL = auto() BIGDECIMAL = auto() CHAR = auto() NCHAR = auto() @@ -418,6 +420,7 @@ class _Tokenizer(type): **_quotes_to_format(TokenType.BYTE_STRING, klass.BYTE_STRINGS), **_quotes_to_format(TokenType.HEX_STRING, klass.HEX_STRINGS), **_quotes_to_format(TokenType.RAW_STRING, klass.RAW_STRINGS), + **_quotes_to_format(TokenType.HEREDOC_STRING, klass.HEREDOC_STRINGS), } klass._STRING_ESCAPES = set(klass.STRING_ESCAPES) @@ -484,11 +487,13 @@ class Tokenizer(metaclass=_Tokenizer): BYTE_STRINGS: t.List[str | t.Tuple[str, str]] = [] HEX_STRINGS: t.List[str | t.Tuple[str, str]] = [] RAW_STRINGS: t.List[str | t.Tuple[str, str]] = [] + HEREDOC_STRINGS: t.List[str | t.Tuple[str, str]] = [] IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"'] IDENTIFIER_ESCAPES = ['"'] QUOTES: t.List[t.Tuple[str, str] | str] = ["'"] STRING_ESCAPES = ["'"] VAR_SINGLE_TOKENS: t.Set[str] = set() + ESCAPE_SEQUENCES: t.Dict[str, str] = {} # Autofilled IDENTIFIERS_CAN_START_WITH_DIGIT: bool = False @@ -997,9 +1002,11 @@ class Tokenizer(metaclass=_Tokenizer): word = word.upper() self._add(self.KEYWORDS[word], text=word) return + if self._char in self.SINGLE_TOKENS: self._add(self.SINGLE_TOKENS[self._char], text=self._char) return + self._scan_var() def _scan_comment(self, comment_start: str) -> bool: @@ -1126,6 +1133,10 @@ class Tokenizer(metaclass=_Tokenizer): base = 16 elif token_type == TokenType.BIT_STRING: base = 2 + elif token_type == TokenType.HEREDOC_STRING: + self._advance() + tag = "" if self._char == end else self._extract_string(end) + end = f"{start}{tag}{end}" else: return False @@ -1193,6 +1204,13 @@ class Tokenizer(metaclass=_Tokenizer): if self._end: raise TokenError(f"Missing {delimiter} from {self._line}:{self._start}") + if self.ESCAPE_SEQUENCES and self._peek and self._char in self.STRING_ESCAPES: + escaped_sequence = self.ESCAPE_SEQUENCES.get(self._char + self._peek) + if escaped_sequence: + self._advance(2) + text += escaped_sequence + continue + current = self._current - 1 self._advance(alnum=True) text += self.sql[current : self._current - 1] |