From 766db5014d053a8aecf75d550c2a1b59022bcabf Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 3 Feb 2023 07:02:50 +0100 Subject: Merging upstream version 10.6.0. Signed-off-by: Daniel Baumann --- sqlglot/__init__.py | 2 +- sqlglot/dataframe/sql/column.py | 4 +- sqlglot/dataframe/sql/functions.py | 2 +- sqlglot/dialects/__init__.py | 15 +- sqlglot/dialects/bigquery.py | 7 + sqlglot/dialects/duckdb.py | 9 +- sqlglot/dialects/hive.py | 2 +- sqlglot/dialects/mysql.py | 2 +- sqlglot/dialects/oracle.py | 10 +- sqlglot/dialects/postgres.py | 14 + sqlglot/dialects/presto.py | 2 +- sqlglot/dialects/redshift.py | 2 +- sqlglot/dialects/snowflake.py | 6 +- sqlglot/dialects/spark.py | 3 +- sqlglot/expressions.py | 85 +++- sqlglot/generator.py | 248 ++++++++++-- sqlglot/optimizer/expand_multi_table_selects.py | 8 + sqlglot/optimizer/isolate_table_selects.py | 2 +- sqlglot/optimizer/optimize_joins.py | 5 + sqlglot/optimizer/optimizer.py | 1 - sqlglot/parser.py | 499 +++++++++++++++++++----- sqlglot/tokens.py | 38 +- sqlglot/transforms.py | 33 +- 23 files changed, 778 insertions(+), 221 deletions(-) (limited to 'sqlglot') diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 67a4463..bfcabb3 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -33,7 +33,7 @@ from sqlglot.parser import Parser from sqlglot.schema import MappingSchema, Schema from sqlglot.tokens import Tokenizer, TokenType -__version__ = "10.5.10" +__version__ = "10.6.0" pretty = False """Whether to format generated SQL by default.""" diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py index 22075e9..40ffe3e 100644 --- a/sqlglot/dataframe/sql/column.py +++ b/sqlglot/dataframe/sql/column.py @@ -94,10 +94,10 @@ class Column: return self.inverse_binary_op(exp.Mod, other) def __pow__(self, power: ColumnOrLiteral, modulo=None): - return Column(exp.Pow(this=self.expression, power=Column(power).expression)) + return Column(exp.Pow(this=self.expression, expression=Column(power).expression)) def __rpow__(self, power: ColumnOrLiteral): - return Column(exp.Pow(this=Column(power).expression, power=self.expression)) + return Column(exp.Pow(this=Column(power).expression, expression=self.expression)) def __invert__(self): return self.unary_op(exp.Not) diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index 1ee361a..a141fe4 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -311,7 +311,7 @@ def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float] def pow(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column: - return Column.invoke_expression_over_column(col1, glotexp.Pow, power=col2) + return Column.invoke_expression_over_column(col1, glotexp.Pow, expression=col2) def row_number() -> Column: diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py index 34cf613..191e703 100644 --- a/sqlglot/dialects/__init__.py +++ b/sqlglot/dialects/__init__.py @@ -1,17 +1,14 @@ """ ## Dialects -One of the core abstractions in SQLGlot is the concept of a "dialect". The `Dialect` class essentially implements a -"SQLGlot dialect", which aims to be as generic and ANSI-compliant as possible. It relies on the base `Tokenizer`, -`Parser` and `Generator` classes to achieve this goal, so these need to be very lenient when it comes to consuming -SQL code. +While there is a SQL standard, most SQL engines support a variation of that standard. 
This makes it difficult +to write portable SQL code. SQLGlot bridges all the different variations, called "dialects", with an extensible +SQL transpilation framework. -However, there are cases where the syntax of different SQL dialects varies wildly, even for common tasks. One such -example is the date/time functions, which can be hard to deal with. For this reason, it's sometimes necessary to -override the base dialect in order to specialize its behavior. This can be easily done in SQLGlot: supporting new -dialects is as simple as subclassing from `Dialect` and overriding its various components (e.g. the `Parser` class), -in order to implement the target behavior. +The base `sqlglot.dialects.dialect.Dialect` class implements a generic dialect that aims to be as universal as possible. +Each SQL variation has its own `Dialect` subclass, extending the corresponding `Tokenizer`, `Parser` and `Generator` +classes as needed. ### Implementing a custom Dialect diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index e7d30ec..27dca48 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -169,6 +169,13 @@ class BigQuery(Dialect): TokenType.VALUES, } + PROPERTY_PARSERS = { + **parser.Parser.PROPERTY_PARSERS, # type: ignore + "NOT DETERMINISTIC": lambda self: self.expression( + exp.VolatilityProperty, this=exp.Literal.string("VOLATILE") + ), + } + class Generator(generator.Generator): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 81941f7..4646eb4 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -66,12 +66,11 @@ def _sort_array_reverse(args): return exp.SortArray(this=seq_get(args, 0), asc=exp.false()) -def _struct_pack_sql(self, expression): +def _struct_sql(self, expression): args = [ - self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e) - for e in expression.expressions + f"'{e.name or e.this.name}': {self.sql(e, 'expression')}" for e in expression.expressions ] - return f"STRUCT_PACK({', '.join(args)})" + return f"{{{', '.join(args)}}}" def _datatype_sql(self, expression): @@ -153,7 +152,7 @@ class DuckDB(Dialect): exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)", exp.StrToTime: _str_to_time_sql, exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))", - exp.Struct: _struct_pack_sql, + exp.Struct: _struct_sql, exp.TableSample: no_tablesample_sql, exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", exp.TimeStrToTime: timestrtotime_sql, diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index ddfd1e8..4bbec70 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -251,7 +251,7 @@ class Hive(Dialect): PROPERTY_PARSERS = { **parser.Parser.PROPERTY_PARSERS, # type: ignore - TokenType.SERDE_PROPERTIES: lambda self: exp.SerdeProperties( + "WITH SERDEPROPERTIES": lambda self: exp.SerdeProperties( expressions=self._parse_wrapped_csv(self._parse_property) ), } diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 2a0a917..cd8c30c 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -202,7 +202,7 @@ class MySQL(Dialect): PROPERTY_PARSERS = { **parser.Parser.PROPERTY_PARSERS, # type: ignore - TokenType.ENGINE: lambda self: self._parse_property_assignment(exp.EngineProperty), + "ENGINE": lambda self: self._parse_property_assignment(exp.EngineProperty), } STATEMENT_PARSERS = { 
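Note on the PROPERTY_PARSERS changes above (BigQuery, Hive, MySQL): the tables are now keyed by raw keyword strings instead of TokenType members, which is also the extension hook the reworked dialects docstring points at under "Implementing a custom Dialect". A minimal sketch of the pattern, assuming sqlglot 10.6.0; MyDialect is a hypothetical dialect and its ENGINE entry simply mirrors the MySQL hunk above:

from sqlglot import exp, parser
from sqlglot.dialects.dialect import Dialect

class MyDialect(Dialect):  # hypothetical dialect, for illustration only
    class Parser(parser.Parser):
        PROPERTY_PARSERS = {
            **parser.Parser.PROPERTY_PARSERS,
            # keys are now the raw keyword text, not TokenType members
            "ENGINE": lambda self: self._parse_property_assignment(exp.EngineProperty),
        }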
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 86caa6b..67d791d 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -74,13 +74,16 @@ class Oracle(Dialect): def query_modifiers(self, expression, *sqls): return csv( *sqls, - *[self.sql(sql) for sql in expression.args.get("laterals", [])], - *[self.sql(sql) for sql in expression.args.get("joins", [])], + *[self.sql(sql) for sql in expression.args.get("joins") or []], + self.sql(expression, "match"), + *[self.sql(sql) for sql in expression.args.get("laterals") or []], self.sql(expression, "where"), self.sql(expression, "group"), self.sql(expression, "having"), self.sql(expression, "qualify"), - self.sql(expression, "window"), + self.seg("WINDOW ") + self.expressions(expression, "windows", flat=True) + if expression.args.get("windows") + else "", self.sql(expression, "distribute"), self.sql(expression, "sort"), self.sql(expression, "cluster"), @@ -99,6 +102,7 @@ class Oracle(Dialect): class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE, "MINUS": TokenType.EXCEPT, "START": TokenType.BEGIN, "TOP": TokenType.TOP, diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 6f597f1..0d74b3a 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import ( no_paren_current_date_sql, no_tablesample_sql, no_trycast_sql, + rename_func, str_position_sql, trim_sql, ) @@ -260,6 +261,16 @@ class Postgres(Dialect): "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"), } + BITWISE = { + **parser.Parser.BITWISE, # type: ignore + TokenType.HASH: exp.BitwiseXor, + } + + FACTOR = { + **parser.Parser.FACTOR, # type: ignore + TokenType.CARET: exp.Pow, + } + class Generator(generator.Generator): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore @@ -273,6 +284,7 @@ class Postgres(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore + exp.BitwiseXor: lambda self, e: self.binary(e, "#"), exp.ColumnDef: preprocess( [ _auto_increment_to_serial, @@ -285,11 +297,13 @@ class Postgres(Dialect): exp.JSONBExtract: lambda self, e: self.binary(e, "#>"), exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"), exp.JSONBContains: lambda self, e: self.binary(e, "?"), + exp.Pow: lambda self, e: self.binary(e, "^"), exp.CurrentDate: no_paren_current_date_sql, exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.DateAdd: _date_add_sql("+"), exp.DateSub: _date_add_sql("-"), exp.DateDiff: _date_diff_sql, + exp.LogicalOr: rename_func("BOOL_OR"), exp.RegexpLike: lambda self, e: self.binary(e, "~"), exp.RegexpILike: lambda self, e: self.binary(e, "~*"), exp.StrPosition: str_position_sql, diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index a79a9f9..8175d6f 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -174,6 +174,7 @@ class Presto(Dialect): "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"), "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"), "FROM_UNIXTIME": _from_unixtime, + "NOW": exp.CurrentTimestamp.from_arg_list, "STRPOS": lambda args: exp.StrPosition( this=seq_get(args, 0), substr=seq_get(args, 1), @@ -194,7 +195,6 @@ class Presto(Dialect): FUNCTION_PARSERS.pop("TRIM") class Generator(generator.Generator): - STRUCT_DELIMITER = ("(", ")") ROOT_PROPERTIES = {exp.SchemaCommentProperty} diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py 
b/sqlglot/dialects/redshift.py
index afd7913..7da881f 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -93,7 +93,7 @@ class Redshift(Postgres): rows = [tuple_exp.expressions for tuple_exp in expression.expressions] selects = [] for i, row in enumerate(rows): - if i == 0: + if i == 0 and expression.alias: row = [ exp.alias_(value, column_name) for value, column_name in zip(row, expression.args["alias"].args["columns"]) diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 6225a53..db72a34 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -178,11 +178,6 @@ class Snowflake(Dialect): ), } - PROPERTY_PARSERS = { - **parser.Parser.PROPERTY_PARSERS, - TokenType.PARTITION_BY: lambda self: self._parse_partitioned_by(), - } - class Tokenizer(tokens.Tokenizer): QUOTES = ["'", "$$"] ESCAPES = ["\\", "'"] @@ -195,6 +190,7 @@ class Snowflake(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "EXCLUDE": TokenType.EXCEPT, + "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE, "RENAME": TokenType.REPLACE, "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ, "TIMESTAMP_NTZ": TokenType.TIMESTAMP, diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 42d34c2..fc711ab 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -1,7 +1,7 @@ from __future__ import annotations from sqlglot import exp, parser -from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func +from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql from sqlglot.dialects.hive import Hive from sqlglot.helper import seq_get @@ -122,6 +122,7 @@ class Spark(Hive): exp.Reduce: rename_func("AGGREGATE"), exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}", exp.TimestampTrunc: lambda self, e: f"DATE_TRUNC({self.sql(e, 'unit')}, {self.sql(e, 'this')})", + exp.Trim: trim_sql, exp.VariancePop: rename_func("VAR_POP"), exp.DateFromParts: rename_func("MAKE_DATE"), exp.LogicalOr: rename_func("BOOL_OR"), diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index f9751ca..7c1a116 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -230,6 +230,7 @@ class Expression(metaclass=_Expression): Returns a deep copy of the expression. 
""" new = deepcopy(self) + new.parent = self.parent for item, parent, _ in new.bfs(): if isinstance(item, Expression) and parent: item.parent = parent @@ -759,6 +760,10 @@ class Create(Expression): "this": True, "kind": True, "expression": False, + "set": False, + "multiset": False, + "global_temporary": False, + "volatile": False, "exists": False, "properties": False, "temporary": False, @@ -1082,7 +1087,7 @@ class LoadData(Expression): class Partition(Expression): - pass + arg_types = {"expressions": True} class Fetch(Expression): @@ -1232,6 +1237,18 @@ class Lateral(UDTF): arg_types = {"this": True, "view": False, "outer": False, "alias": False} +class MatchRecognize(Expression): + arg_types = { + "partition_by": False, + "order": False, + "measures": False, + "rows": False, + "after": False, + "pattern": False, + "define": False, + } + + # Clickhouse FROM FINAL modifier # https://clickhouse.com/docs/en/sql-reference/statements/select/from/#final-modifier class Final(Expression): @@ -1357,8 +1374,58 @@ class SerdeProperties(Property): arg_types = {"expressions": True} +class FallbackProperty(Property): + arg_types = {"no": True, "protection": False} + + +class WithJournalTableProperty(Property): + arg_types = {"this": True} + + +class LogProperty(Property): + arg_types = {"no": True} + + +class JournalProperty(Property): + arg_types = {"no": True, "dual": False, "before": False} + + +class AfterJournalProperty(Property): + arg_types = {"no": True, "dual": False, "local": False} + + +class ChecksumProperty(Property): + arg_types = {"on": False, "default": False} + + +class FreespaceProperty(Property): + arg_types = {"this": True, "percent": False} + + +class MergeBlockRatioProperty(Property): + arg_types = {"this": False, "no": False, "default": False, "percent": False} + + +class DataBlocksizeProperty(Property): + arg_types = {"size": False, "units": False, "min": False, "default": False} + + +class BlockCompressionProperty(Property): + arg_types = {"autotemp": False, "always": False, "default": True, "manual": True, "never": True} + + +class IsolatedLoadingProperty(Property): + arg_types = { + "no": True, + "concurrent": True, + "for_all": True, + "for_insert": True, + "for_none": True, + } + + class Properties(Expression): - arg_types = {"expressions": True} + arg_types = {"expressions": True, "before": False} NAME_TO_PROPERTY = { "AUTO_INCREMENT": AutoIncrementProperty, @@ -1510,6 +1577,7 @@ class Subqueryable(Unionable): QUERY_MODIFIERS = { + "match": False, "laterals": False, "joins": False, "pivots": False, @@ -2459,6 +2527,10 @@ class AddConstraint(Expression): arg_types = {"this": False, "expression": False, "enforced": False} +class DropPartition(Expression): + arg_types = {"expressions": True, "exists": False} + + # Binary expressions like (ADD a b) class Binary(Expression): arg_types = {"this": True, "expression": True} @@ -2540,6 +2612,10 @@ class Escape(Binary): pass +class Glob(Binary, Predicate): + pass + + class GT(Binary, Predicate): pass @@ -3126,8 +3202,7 @@ class Posexplode(Func): pass -class Pow(Func): - arg_types = {"this": True, "power": True} +class Pow(Binary, Func): _sql_names = ["POWER", "POW"] @@ -3361,7 +3436,7 @@ class Year(Func): class Use(Expression): - pass + arg_types = {"this": True, "kind": False} class Merge(Expression): diff --git a/sqlglot/generator.py b/sqlglot/generator.py index b398d8e..3f3365a 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -65,6 +65,8 @@ class Generator: exp.ReturnsProperty: lambda self, e: 
self.naked_property(e), exp.ExecuteAsProperty: lambda self, e: self.naked_property(e), exp.VolatilityProperty: lambda self, e: e.name, + exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}", + exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG", } # Whether 'CREATE ... TRANSIENT ... TABLE' is allowed @@ -97,6 +99,20 @@ class Generator: STRUCT_DELIMITER = ("<", ">") + BEFORE_PROPERTIES = { + exp.FallbackProperty, + exp.WithJournalTableProperty, + exp.LogProperty, + exp.JournalProperty, + exp.AfterJournalProperty, + exp.ChecksumProperty, + exp.FreespaceProperty, + exp.MergeBlockRatioProperty, + exp.DataBlocksizeProperty, + exp.BlockCompressionProperty, + exp.IsolatedLoadingProperty, + } + ROOT_PROPERTIES = { exp.ReturnsProperty, exp.LanguageProperty, @@ -113,8 +129,6 @@ class Generator: exp.TableFormatProperty, } - WITH_SINGLE_ALTER_TABLE_ACTION = (exp.AlterColumn, exp.RenameTable, exp.AddConstraint) - WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary) SENTINEL_LINE_BREAK = "__SQLGLOT__LB__" @@ -122,7 +136,6 @@ class Generator: "time_mapping", "time_trie", "pretty", - "configured_pretty", "quote_start", "quote_end", "identifier_start", @@ -177,7 +190,6 @@ class Generator: self.time_mapping = time_mapping or {} self.time_trie = time_trie self.pretty = pretty if pretty is not None else sqlglot.pretty - self.configured_pretty = self.pretty self.quote_start = quote_start or "'" self.quote_end = quote_end or "'" self.identifier_start = identifier_start or '"' @@ -442,8 +454,20 @@ class Generator: return "UNIQUE" def create_sql(self, expression: exp.Create) -> str: - this = self.sql(expression, "this") kind = self.sql(expression, "kind").upper() + has_before_properties = expression.args.get("properties") + has_before_properties = ( + has_before_properties.args.get("before") if has_before_properties else None + ) + if kind == "TABLE" and has_before_properties: + this_name = self.sql(expression.this, "this") + this_properties = self.sql(expression, "properties") + this_schema = f"({self.expressions(expression.this)})" + this = f"{this_name}, {this_properties} {this_schema}" + properties = "" + else: + this = self.sql(expression, "this") + properties = self.sql(expression, "properties") begin = " BEGIN" if expression.args.get("begin") else "" expression_sql = self.sql(expression, "expression") expression_sql = f" AS{begin}{self.sep()}{expression_sql}" if expression_sql else "" @@ -456,7 +480,10 @@ class Generator: exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else "" unique = " UNIQUE" if expression.args.get("unique") else "" materialized = " MATERIALIZED" if expression.args.get("materialized") else "" - properties = self.sql(expression, "properties") + set_ = " SET" if expression.args.get("set") else "" + multiset = " MULTISET" if expression.args.get("multiset") else "" + global_temporary = " GLOBAL TEMPORARY" if expression.args.get("global_temporary") else "" + volatile = " VOLATILE" if expression.args.get("volatile") else "" data = expression.args.get("data") if data is None: data = "" @@ -475,7 +502,7 @@ class Generator: indexes = expression.args.get("indexes") index_sql = "" - if indexes is not None: + if indexes: indexes_sql = [] for index in indexes: ind_unique = " UNIQUE" if index.args.get("unique") else "" @@ -500,6 +527,10 @@ class Generator: external, unique, materialized, + set_, + multiset, + global_temporary, + volatile, ) ) no_schema_binding = ( @@ -569,13 +600,14 @@ class Generator: def 
delete_sql(self, expression: exp.Delete) -> str: this = self.sql(expression, "this") + this = f" FROM {this}" if this else "" using_sql = ( f" USING {self.expressions(expression, 'using', sep=', USING ')}" if expression.args.get("using") else "" ) where_sql = self.sql(expression, "where") - sql = f"DELETE FROM {this}{using_sql}{where_sql}" + sql = f"DELETE{this}{using_sql}{where_sql}" return self.prepend_ctes(expression, sql) def drop_sql(self, expression: exp.Drop) -> str: @@ -630,28 +662,27 @@ class Generator: return f"N{self.sql(expression, 'this')}" def partition_sql(self, expression: exp.Partition) -> str: - keys = csv( - *[ - f"""{prop.name}='{prop.text("value")}'""" if prop.text("value") else prop.name - for prop in expression.this - ] - ) - return f"PARTITION({keys})" + return f"PARTITION({self.expressions(expression)})" def properties_sql(self, expression: exp.Properties) -> str: + before_properties = [] root_properties = [] with_properties = [] for p in expression.expressions: p_class = p.__class__ - if p_class in self.WITH_PROPERTIES: + if p_class in self.BEFORE_PROPERTIES: + before_properties.append(p) + elif p_class in self.WITH_PROPERTIES: with_properties.append(p) elif p_class in self.ROOT_PROPERTIES: root_properties.append(p) - return self.root_properties( - exp.Properties(expressions=root_properties) - ) + self.with_properties(exp.Properties(expressions=with_properties)) + return ( + self.properties(exp.Properties(expressions=before_properties), before=True) + + self.root_properties(exp.Properties(expressions=root_properties)) + + self.with_properties(exp.Properties(expressions=with_properties)) + ) def root_properties(self, properties: exp.Properties) -> str: if properties.expressions: @@ -659,13 +690,17 @@ class Generator: return "" def properties( - self, properties: exp.Properties, prefix: str = "", sep: str = ", ", suffix: str = "" + self, + properties: exp.Properties, + prefix: str = "", + sep: str = ", ", + suffix: str = "", + before: bool = False, ) -> str: if properties.expressions: expressions = self.expressions(properties, sep=sep, indent=False) - return ( - f"{prefix}{' ' if prefix and prefix != ' ' else ''}{self.wrap(expressions)}{suffix}" - ) + expressions = expressions if before else self.wrap(expressions) + return f"{prefix}{' ' if prefix and prefix != ' ' else ''}{expressions}{suffix}" return "" def with_properties(self, properties: exp.Properties) -> str: @@ -687,6 +722,98 @@ class Generator: options = f" {options}" if options else "" return f"LIKE {self.sql(expression, 'this')}{options}" + def fallbackproperty_sql(self, expression: exp.FallbackProperty) -> str: + no = "NO " if expression.args.get("no") else "" + protection = " PROTECTION" if expression.args.get("protection") else "" + return f"{no}FALLBACK{protection}" + + def journalproperty_sql(self, expression: exp.JournalProperty) -> str: + no = "NO " if expression.args.get("no") else "" + dual = "DUAL " if expression.args.get("dual") else "" + before = "BEFORE " if expression.args.get("before") else "" + return f"{no}{dual}{before}JOURNAL" + + def freespaceproperty_sql(self, expression: exp.FreespaceProperty) -> str: + freespace = self.sql(expression, "this") + percent = " PERCENT" if expression.args.get("percent") else "" + return f"FREESPACE={freespace}{percent}" + + def afterjournalproperty_sql(self, expression: exp.AfterJournalProperty) -> str: + no = "NO " if expression.args.get("no") else "" + dual = "DUAL " if expression.args.get("dual") else "" + local = "" + if expression.args.get("local") 
is not None: + local = "LOCAL " if expression.args.get("local") else "NOT LOCAL " + return f"{no}{dual}{local}AFTER JOURNAL" + + def checksumproperty_sql(self, expression: exp.ChecksumProperty) -> str: + if expression.args.get("default"): + property = "DEFAULT" + elif expression.args.get("on"): + property = "ON" + else: + property = "OFF" + return f"CHECKSUM={property}" + + def mergeblockratioproperty_sql(self, expression: exp.MergeBlockRatioProperty) -> str: + if expression.args.get("no"): + return "NO MERGEBLOCKRATIO" + if expression.args.get("default"): + return "DEFAULT MERGEBLOCKRATIO" + + percent = " PERCENT" if expression.args.get("percent") else "" + return f"MERGEBLOCKRATIO={self.sql(expression, 'this')}{percent}" + + def datablocksizeproperty_sql(self, expression: exp.DataBlocksizeProperty) -> str: + default = expression.args.get("default") + min = expression.args.get("min") + if default is not None or min is not None: + if default: + property = "DEFAULT" + elif min: + property = "MINIMUM" + else: + property = "MAXIMUM" + return f"{property} DATABLOCKSIZE" + else: + units = expression.args.get("units") + units = f" {units}" if units else "" + return f"DATABLOCKSIZE={self.sql(expression, 'size')}{units}" + + def blockcompressionproperty_sql(self, expression: exp.BlockCompressionProperty) -> str: + autotemp = expression.args.get("autotemp") + always = expression.args.get("always") + default = expression.args.get("default") + manual = expression.args.get("manual") + never = expression.args.get("never") + + if autotemp is not None: + property = f"AUTOTEMP({self.expressions(autotemp)})" + elif always: + property = "ALWAYS" + elif default: + property = "DEFAULT" + elif manual: + property = "MANUAL" + elif never: + property = "NEVER" + return f"BLOCKCOMPRESSION={property}" + + def isolatedloadingproperty_sql(self, expression: exp.IsolatedLoadingProperty) -> str: + no = expression.args.get("no") + no = " NO" if no else "" + concurrent = expression.args.get("concurrent") + concurrent = " CONCURRENT" if concurrent else "" + + for_ = "" + if expression.args.get("for_all"): + for_ = " FOR ALL" + elif expression.args.get("for_insert"): + for_ = " FOR INSERT" + elif expression.args.get("for_none"): + for_ = " FOR NONE" + return f"WITH{no}{concurrent} ISOLATED LOADING{for_}" + def insert_sql(self, expression: exp.Insert) -> str: overwrite = expression.args.get("overwrite") @@ -833,10 +960,21 @@ class Generator: grouping_sets = ( f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else "" ) - cube = self.expressions(expression, key="cube", indent=False) - cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else "" - rollup = self.expressions(expression, key="rollup", indent=False) - rollup = f"{self.seg('ROLLUP')} {self.wrap(rollup)}" if rollup else "" + + cube = expression.args.get("cube") + if cube is True: + cube = self.seg("WITH CUBE") + else: + cube = self.expressions(expression, key="cube", indent=False) + cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else "" + + rollup = expression.args.get("rollup") + if rollup is True: + rollup = self.seg("WITH ROLLUP") + else: + rollup = self.expressions(expression, key="rollup", indent=False) + rollup = f"{self.seg('ROLLUP')} {self.wrap(rollup)}" if rollup else "" + return f"{group_by}{grouping_sets}{cube}{rollup}" def having_sql(self, expression: exp.Having) -> str: @@ -980,10 +1118,37 @@ class Generator: return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}" + def matchrecognize_sql(self, 
expression: exp.MatchRecognize) -> str: + partition = self.partition_by_sql(expression) + order = self.sql(expression, "order") + measures = self.sql(expression, "measures") + measures = self.seg(f"MEASURES {measures}") if measures else "" + rows = self.sql(expression, "rows") + rows = self.seg(rows) if rows else "" + after = self.sql(expression, "after") + after = self.seg(after) if after else "" + pattern = self.sql(expression, "pattern") + pattern = self.seg(f"PATTERN ({pattern})") if pattern else "" + define = self.sql(expression, "define") + define = self.seg(f"DEFINE {define}") if define else "" + body = "".join( + ( + partition, + order, + measures, + rows, + after, + pattern, + define, + ) + ) + return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}" + def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str: return csv( *sqls, *[self.sql(sql) for sql in expression.args.get("joins") or []], + self.sql(expression, "match"), *[self.sql(sql) for sql in expression.args.get("laterals") or []], self.sql(expression, "where"), self.sql(expression, "group"), @@ -1092,8 +1257,7 @@ class Generator: def window_sql(self, expression: exp.Window) -> str: this = self.sql(expression, "this") - partition = self.expressions(expression, key="partition_by", flat=True) - partition = f"PARTITION BY {partition}" if partition else "" + partition = self.partition_by_sql(expression) order = expression.args.get("order") order_sql = self.order_sql(order, flat=True) if order else "" @@ -1113,6 +1277,10 @@ class Generator: return f"{this} ({window_args.strip()})" + def partition_by_sql(self, expression: exp.Window | exp.MatchRecognize) -> str: + partition = self.expressions(expression, key="partition_by", flat=True) + return f"PARTITION BY {partition}" if partition else "" + def window_spec_sql(self, expression: exp.WindowSpec) -> str: kind = self.sql(expression, "kind") start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ") @@ -1386,16 +1554,19 @@ class Generator: actions = self.expressions(expression, "actions", prefix="ADD COLUMN ") elif isinstance(actions[0], exp.Schema): actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ") - elif isinstance(actions[0], exp.Drop): - actions = self.expressions(expression, "actions") - elif isinstance(actions[0], self.WITH_SINGLE_ALTER_TABLE_ACTION): - actions = self.sql(actions[0]) + elif isinstance(actions[0], exp.Delete): + actions = self.expressions(expression, "actions", flat=True) else: - self.unsupported(f"Unsupported ALTER TABLE action {actions[0].__class__.__name__}") + actions = self.expressions(expression, "actions") exists = " IF EXISTS" if expression.args.get("exists") else "" return f"ALTER TABLE{exists} {self.sql(expression, 'this')} {actions}" + def droppartition_sql(self, expression: exp.DropPartition) -> str: + expressions = self.expressions(expression) + exists = " IF EXISTS " if expression.args.get("exists") else " " + return f"DROP{exists}{expressions}" + def addconstraint_sql(self, expression: exp.AddConstraint) -> str: this = self.sql(expression, "this") expression_ = self.sql(expression, "expression") @@ -1447,6 +1618,9 @@ class Generator: def escape_sql(self, expression: exp.Escape) -> str: return self.binary(expression, "ESCAPE") + def glob_sql(self, expression: exp.Glob) -> str: + return self.binary(expression, "GLOB") + def gt_sql(self, expression: exp.GT) -> str: return self.binary(expression, ">") @@ -1499,7 +1673,11 @@ class Generator: return f"TRY_CAST({self.sql(expression, 
'this')} AS {self.sql(expression, 'to')})" def use_sql(self, expression: exp.Use) -> str: - return f"USE {self.sql(expression, 'this')}" + kind = self.sql(expression, "kind") + kind = f" {kind}" if kind else "" + this = self.sql(expression, "this") + this = f" {this}" if this else "" + return f"USE{kind}{this}" def binary(self, expression: exp.Binary, op: str) -> str: return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}" diff --git a/sqlglot/optimizer/expand_multi_table_selects.py b/sqlglot/optimizer/expand_multi_table_selects.py index ba562df..86f0c2d 100644 --- a/sqlglot/optimizer/expand_multi_table_selects.py +++ b/sqlglot/optimizer/expand_multi_table_selects.py @@ -2,6 +2,14 @@ from sqlglot import exp def expand_multi_table_selects(expression): + """ + Replace multiple FROM expressions with JOINs. + + Example: + >>> from sqlglot import parse_one + >>> expand_multi_table_selects(parse_one("SELECT * FROM x, y")).sql() + 'SELECT * FROM x CROSS JOIN y' + """ for from_ in expression.find_all(exp.From): parent = from_.parent diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py index 5bd7b30..5d78353 100644 --- a/sqlglot/optimizer/isolate_table_selects.py +++ b/sqlglot/optimizer/isolate_table_selects.py @@ -11,7 +11,7 @@ def isolate_table_selects(expression, schema=None): if len(scope.selected_sources) == 1: continue - for (_, source) in scope.selected_sources.values(): + for _, source in scope.selected_sources.values(): if not isinstance(source, exp.Table) or not schema.column_names(source): continue diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py index fd69832..dc5ce44 100644 --- a/sqlglot/optimizer/optimize_joins.py +++ b/sqlglot/optimizer/optimize_joins.py @@ -6,6 +6,11 @@ from sqlglot.optimizer.simplify import simplify def optimize_joins(expression): """ Removes cross joins if possible and reorder joins based on predicate dependencies. 
+ + Example: + >>> from sqlglot import parse_one + >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql() + 'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a' """ for select in expression.find_all(exp.Select): references = {} diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index 5258c2b..766e059 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -64,7 +64,6 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs} expression = expression.copy() for rule in rules: - # Find any additional rule parameters, beyond `expression` rule_params = rule.__code__.co_varnames rule_kwargs = { diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 42777d1..6229105 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -175,13 +175,9 @@ class Parser(metaclass=_Parser): TokenType.DEFAULT, TokenType.DELETE, TokenType.DESCRIBE, - TokenType.DETERMINISTIC, TokenType.DIV, - TokenType.DISTKEY, - TokenType.DISTSTYLE, TokenType.END, TokenType.EXECUTE, - TokenType.ENGINE, TokenType.ESCAPE, TokenType.FALSE, TokenType.FIRST, @@ -194,13 +190,10 @@ class Parser(metaclass=_Parser): TokenType.IF, TokenType.INDEX, TokenType.ISNULL, - TokenType.IMMUTABLE, TokenType.INTERVAL, TokenType.LAZY, - TokenType.LANGUAGE, TokenType.LEADING, TokenType.LOCAL, - TokenType.LOCATION, TokenType.MATERIALIZED, TokenType.MERGE, TokenType.NATURAL, @@ -209,13 +202,11 @@ class Parser(metaclass=_Parser): TokenType.ONLY, TokenType.OPTIONS, TokenType.ORDINALITY, - TokenType.PARTITIONED_BY, TokenType.PERCENT, TokenType.PIVOT, TokenType.PRECEDING, TokenType.RANGE, TokenType.REFERENCES, - TokenType.RETURNS, TokenType.ROW, TokenType.ROWS, TokenType.SCHEMA, @@ -225,10 +216,7 @@ class Parser(metaclass=_Parser): TokenType.SET, TokenType.SHOW, TokenType.SORTKEY, - TokenType.STABLE, - TokenType.STORED, TokenType.TABLE, - TokenType.TABLE_FORMAT, TokenType.TEMPORARY, TokenType.TOP, TokenType.TRAILING, @@ -237,7 +225,6 @@ class Parser(metaclass=_Parser): TokenType.UNIQUE, TokenType.UNLOGGED, TokenType.UNPIVOT, - TokenType.PROPERTIES, TokenType.PROCEDURE, TokenType.VIEW, TokenType.VOLATILE, @@ -448,7 +435,12 @@ class Parser(metaclass=_Parser): TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(), TokenType.UNCACHE: lambda self: self._parse_uncache(), TokenType.UPDATE: lambda self: self._parse_update(), - TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()), + TokenType.USE: lambda self: self.expression( + exp.Use, + kind=self._match_texts(("ROLE", "WAREHOUSE", "DATABASE", "SCHEMA")) + and exp.Var(this=self._prev.text), + this=self._parse_table(schema=False), + ), } UNARY_PARSERS = { @@ -492,6 +484,9 @@ class Parser(metaclass=_Parser): RANGE_PARSERS = { TokenType.BETWEEN: lambda self, this: self._parse_between(this), + TokenType.GLOB: lambda self, this: self._parse_escape( + self.expression(exp.Glob, this=this, expression=self._parse_bitwise()) + ), TokenType.IN: lambda self, this: self._parse_in(this), TokenType.IS: lambda self, this: self._parse_is(this), TokenType.LIKE: lambda self, this: self._parse_escape( @@ -512,45 +507,66 @@ class Parser(metaclass=_Parser): } PROPERTY_PARSERS = { - TokenType.AUTO_INCREMENT: lambda self: self._parse_property_assignment( - exp.AutoIncrementProperty - ), - TokenType.CHARACTER_SET: lambda self: self._parse_character_set(), - TokenType.LOCATION: lambda self: 
self._parse_property_assignment(exp.LocationProperty), - TokenType.PARTITIONED_BY: lambda self: self._parse_partitioned_by(), - TokenType.SCHEMA_COMMENT: lambda self: self._parse_property_assignment( - exp.SchemaCommentProperty - ), - TokenType.STORED: lambda self: self._parse_property_assignment(exp.FileFormatProperty), - TokenType.DISTKEY: lambda self: self._parse_distkey(), - TokenType.DISTSTYLE: lambda self: self._parse_property_assignment(exp.DistStyleProperty), - TokenType.SORTKEY: lambda self: self._parse_sortkey(), - TokenType.LIKE: lambda self: self._parse_create_like(), - TokenType.RETURNS: lambda self: self._parse_returns(), - TokenType.ROW: lambda self: self._parse_row(), - TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty), - TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty), - TokenType.FORMAT: lambda self: self._parse_property_assignment(exp.FileFormatProperty), - TokenType.TABLE_FORMAT: lambda self: self._parse_property_assignment( - exp.TableFormatProperty - ), - TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty), - TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty), - TokenType.EXECUTE: lambda self: self._parse_property_assignment(exp.ExecuteAsProperty), - TokenType.DETERMINISTIC: lambda self: self.expression( + "AUTO_INCREMENT": lambda self: self._parse_property_assignment(exp.AutoIncrementProperty), + "CHARACTER SET": lambda self: self._parse_character_set(), + "LOCATION": lambda self: self._parse_property_assignment(exp.LocationProperty), + "PARTITION BY": lambda self: self._parse_partitioned_by(), + "PARTITIONED BY": lambda self: self._parse_partitioned_by(), + "PARTITIONED_BY": lambda self: self._parse_partitioned_by(), + "COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty), + "STORED": lambda self: self._parse_property_assignment(exp.FileFormatProperty), + "DISTKEY": lambda self: self._parse_distkey(), + "DISTSTYLE": lambda self: self._parse_property_assignment(exp.DistStyleProperty), + "SORTKEY": lambda self: self._parse_sortkey(), + "LIKE": lambda self: self._parse_create_like(), + "RETURNS": lambda self: self._parse_returns(), + "ROW": lambda self: self._parse_row(), + "COLLATE": lambda self: self._parse_property_assignment(exp.CollateProperty), + "FORMAT": lambda self: self._parse_property_assignment(exp.FileFormatProperty), + "TABLE_FORMAT": lambda self: self._parse_property_assignment(exp.TableFormatProperty), + "USING": lambda self: self._parse_property_assignment(exp.TableFormatProperty), + "LANGUAGE": lambda self: self._parse_property_assignment(exp.LanguageProperty), + "EXECUTE": lambda self: self._parse_property_assignment(exp.ExecuteAsProperty), + "DETERMINISTIC": lambda self: self.expression( exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE") ), - TokenType.IMMUTABLE: lambda self: self.expression( + "IMMUTABLE": lambda self: self.expression( exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE") ), - TokenType.STABLE: lambda self: self.expression( + "STABLE": lambda self: self.expression( exp.VolatilityProperty, this=exp.Literal.string("STABLE") ), - TokenType.VOLATILE: lambda self: self.expression( + "VOLATILE": lambda self: self.expression( exp.VolatilityProperty, this=exp.Literal.string("VOLATILE") ), - TokenType.WITH: lambda self: self._parse_wrapped_csv(self._parse_property), - TokenType.PROPERTIES: lambda self: self._parse_wrapped_csv(self._parse_property), 
+ "WITH": lambda self: self._parse_with_property(), + "TBLPROPERTIES": lambda self: self._parse_wrapped_csv(self._parse_property), + "FALLBACK": lambda self: self._parse_fallback(no=self._prev.text.upper() == "NO"), + "LOG": lambda self: self._parse_log(no=self._prev.text.upper() == "NO"), + "BEFORE": lambda self: self._parse_journal( + no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL" + ), + "JOURNAL": lambda self: self._parse_journal( + no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL" + ), + "AFTER": lambda self: self._parse_afterjournal( + no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL" + ), + "LOCAL": lambda self: self._parse_afterjournal(no=False, dual=False, local=True), + "NOT": lambda self: self._parse_afterjournal(no=False, dual=False, local=False), + "CHECKSUM": lambda self: self._parse_checksum(), + "FREESPACE": lambda self: self._parse_freespace(), + "MERGEBLOCKRATIO": lambda self: self._parse_mergeblockratio( + no=self._prev.text.upper() == "NO", default=self._prev.text.upper() == "DEFAULT" + ), + "MIN": lambda self: self._parse_datablocksize(), + "MINIMUM": lambda self: self._parse_datablocksize(), + "MAX": lambda self: self._parse_datablocksize(), + "MAXIMUM": lambda self: self._parse_datablocksize(), + "DATABLOCKSIZE": lambda self: self._parse_datablocksize( + default=self._prev.text.upper() == "DEFAULT" + ), + "BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(), } CONSTRAINT_PARSERS = { @@ -580,6 +596,7 @@ class Parser(metaclass=_Parser): } QUERY_MODIFIER_PARSERS = { + "match": lambda self: self._parse_match_recognize(), "where": lambda self: self._parse_where(), "group": lambda self: self._parse_group(), "having": lambda self: self._parse_having(), @@ -627,7 +644,6 @@ class Parser(metaclass=_Parser): "max_errors", "null_ordering", "_tokens", - "_chunks", "_index", "_curr", "_next", @@ -660,7 +676,6 @@ class Parser(metaclass=_Parser): self.sql = "" self.errors = [] self._tokens = [] - self._chunks = [[]] self._index = 0 self._curr = None self._next = None @@ -728,17 +743,18 @@ class Parser(metaclass=_Parser): self.reset() self.sql = sql or "" total = len(raw_tokens) + chunks: t.List[t.List[Token]] = [[]] for i, token in enumerate(raw_tokens): if token.token_type == TokenType.SEMICOLON: if i < total - 1: - self._chunks.append([]) + chunks.append([]) else: - self._chunks[-1].append(token) + chunks[-1].append(token) expressions = [] - for tokens in self._chunks: + for tokens in chunks: self._index = -1 self._tokens = tokens self._advance() @@ -771,7 +787,7 @@ class Parser(metaclass=_Parser): error level setting. 
""" token = token or self._curr or self._prev or Token.string("") - start = self._find_token(token, self.sql) + start = self._find_token(token) end = start + len(token.text) start_context = self.sql[max(start - self.error_message_context, 0) : start] highlight = self.sql[start:end] @@ -833,13 +849,16 @@ class Parser(metaclass=_Parser): for error_message in expression.error_messages(args): self.raise_error(error_message) - def _find_token(self, token: Token, sql: str) -> int: + def _find_sql(self, start: Token, end: Token) -> str: + return self.sql[self._find_token(start) : self._find_token(end)] + + def _find_token(self, token: Token) -> int: line = 1 col = 1 index = 0 while line < token.line or col < token.col: - if Tokenizer.WHITE_SPACE.get(sql[index]) == TokenType.BREAK: + if Tokenizer.WHITE_SPACE.get(self.sql[index]) == TokenType.BREAK: line += 1 col = 1 else: @@ -911,6 +930,10 @@ class Parser(metaclass=_Parser): def _parse_create(self) -> t.Optional[exp.Expression]: replace = self._match_pair(TokenType.OR, TokenType.REPLACE) + set_ = self._match(TokenType.SET) # Teradata + multiset = self._match_text_seq("MULTISET") # Teradata + global_temporary = self._match_text_seq("GLOBAL", "TEMPORARY") # Teradata + volatile = self._match(TokenType.VOLATILE) # Teradata temporary = self._match(TokenType.TEMPORARY) transient = self._match_text_seq("TRANSIENT") external = self._match_text_seq("EXTERNAL") @@ -954,10 +977,18 @@ class Parser(metaclass=_Parser): TokenType.VIEW, TokenType.SCHEMA, ): - this = self._parse_table(schema=True) - properties = self._parse_properties() - if self._match(TokenType.ALIAS): - expression = self._parse_ddl_select() + table_parts = self._parse_table_parts(schema=True) + + if self._match(TokenType.COMMA): # comma-separated properties before schema definition + properties = self._parse_properties(before=True) + + this = self._parse_schema(this=table_parts) + + if not properties: # properties after schema definition + properties = self._parse_properties() + + self._match(TokenType.ALIAS) + expression = self._parse_ddl_select() if create_token.token_type == TokenType.TABLE: if self._match_text_seq("WITH", "DATA"): @@ -988,6 +1019,10 @@ class Parser(metaclass=_Parser): this=this, kind=create_token.text, expression=expression, + set=set_, + multiset=multiset, + global_temporary=global_temporary, + volatile=volatile, exists=exists, properties=properties, temporary=temporary, @@ -1004,9 +1039,19 @@ class Parser(metaclass=_Parser): begin=begin, ) + def _parse_property_before(self) -> t.Optional[exp.Expression]: + self._match_text_seq("NO") + self._match_text_seq("DUAL") + self._match_text_seq("DEFAULT") + + if self.PROPERTY_PARSERS.get(self._curr.text.upper()): + return self.PROPERTY_PARSERS[self._curr.text.upper()](self) + + return None + def _parse_property(self) -> t.Optional[exp.Expression]: - if self._match_set(self.PROPERTY_PARSERS): - return self.PROPERTY_PARSERS[self._prev.token_type](self) + if self._match_texts(self.PROPERTY_PARSERS): + return self.PROPERTY_PARSERS[self._prev.text.upper()](self) if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET): return self._parse_character_set(True) @@ -1033,6 +1078,166 @@ class Parser(metaclass=_Parser): this=self._parse_var_or_string() or self._parse_number() or self._parse_id_var(), ) + def _parse_properties(self, before=None) -> t.Optional[exp.Expression]: + properties = [] + + while True: + if before: + self._match(TokenType.COMMA) + identified_property = self._parse_property_before() + else: + identified_property = 
self._parse_property() + + if not identified_property: + break + for p in ensure_collection(identified_property): + properties.append(p) + + if properties: + return self.expression(exp.Properties, expressions=properties, before=before) + + return None + + def _parse_fallback(self, no=False) -> exp.Expression: + self._match_text_seq("FALLBACK") + return self.expression( + exp.FallbackProperty, no=no, protection=self._match_text_seq("PROTECTION") + ) + + def _parse_with_property( + self, + ) -> t.Union[t.Optional[exp.Expression], t.List[t.Optional[exp.Expression]]]: + if self._match(TokenType.L_PAREN, advance=False): + return self._parse_wrapped_csv(self._parse_property) + + if not self._next: + return None + + if self._next.text.upper() == "JOURNAL": + return self._parse_withjournaltable() + + return self._parse_withisolatedloading() + + def _parse_withjournaltable(self) -> exp.Expression: + self._match_text_seq("WITH", "JOURNAL", "TABLE") + self._match(TokenType.EQ) + return self.expression(exp.WithJournalTableProperty, this=self._parse_table_parts()) + + def _parse_log(self, no=False) -> exp.Expression: + self._match_text_seq("LOG") + return self.expression(exp.LogProperty, no=no) + + def _parse_journal(self, no=False, dual=False) -> exp.Expression: + before = self._match_text_seq("BEFORE") + self._match_text_seq("JOURNAL") + return self.expression(exp.JournalProperty, no=no, dual=dual, before=before) + + def _parse_afterjournal(self, no=False, dual=False, local=None) -> exp.Expression: + self._match_text_seq("NOT") + self._match_text_seq("LOCAL") + self._match_text_seq("AFTER", "JOURNAL") + return self.expression(exp.AfterJournalProperty, no=no, dual=dual, local=local) + + def _parse_checksum(self) -> exp.Expression: + self._match_text_seq("CHECKSUM") + self._match(TokenType.EQ) + + on = None + if self._match(TokenType.ON): + on = True + elif self._match_text_seq("OFF"): + on = False + default = self._match(TokenType.DEFAULT) + + return self.expression( + exp.ChecksumProperty, + on=on, + default=default, + ) + + def _parse_freespace(self) -> exp.Expression: + self._match_text_seq("FREESPACE") + self._match(TokenType.EQ) + return self.expression( + exp.FreespaceProperty, this=self._parse_number(), percent=self._match(TokenType.PERCENT) + ) + + def _parse_mergeblockratio(self, no=False, default=False) -> exp.Expression: + self._match_text_seq("MERGEBLOCKRATIO") + if self._match(TokenType.EQ): + return self.expression( + exp.MergeBlockRatioProperty, + this=self._parse_number(), + percent=self._match(TokenType.PERCENT), + ) + else: + return self.expression( + exp.MergeBlockRatioProperty, + no=no, + default=default, + ) + + def _parse_datablocksize(self, default=None) -> exp.Expression: + if default: + self._match_text_seq("DATABLOCKSIZE") + return self.expression(exp.DataBlocksizeProperty, default=True) + elif self._match_texts(("MIN", "MINIMUM")): + self._match_text_seq("DATABLOCKSIZE") + return self.expression(exp.DataBlocksizeProperty, min=True) + elif self._match_texts(("MAX", "MAXIMUM")): + self._match_text_seq("DATABLOCKSIZE") + return self.expression(exp.DataBlocksizeProperty, min=False) + + self._match_text_seq("DATABLOCKSIZE") + self._match(TokenType.EQ) + size = self._parse_number() + units = None + if self._match_texts(("BYTES", "KBYTES", "KILOBYTES")): + units = self._prev.text + return self.expression(exp.DataBlocksizeProperty, size=size, units=units) + + def _parse_blockcompression(self) -> exp.Expression: + self._match_text_seq("BLOCKCOMPRESSION") + self._match(TokenType.EQ) + 
always = self._match(TokenType.ALWAYS) + manual = self._match_text_seq("MANUAL") + never = self._match_text_seq("NEVER") + default = self._match_text_seq("DEFAULT") + autotemp = None + if self._match_text_seq("AUTOTEMP"): + autotemp = self._parse_schema() + + return self.expression( + exp.BlockCompressionProperty, + always=always, + manual=manual, + never=never, + default=default, + autotemp=autotemp, + ) + + def _parse_withisolatedloading(self) -> exp.Expression: + self._match(TokenType.WITH) + no = self._match_text_seq("NO") + concurrent = self._match_text_seq("CONCURRENT") + self._match_text_seq("ISOLATED", "LOADING") + for_all = self._match_text_seq("FOR", "ALL") + for_insert = self._match_text_seq("FOR", "INSERT") + for_none = self._match_text_seq("FOR", "NONE") + return self.expression( + exp.IsolatedLoadingProperty, + no=no, + concurrent=concurrent, + for_all=for_all, + for_insert=for_insert, + for_none=for_none, + ) + + def _parse_partition_by(self) -> t.List[t.Optional[exp.Expression]]: + if self._match(TokenType.PARTITION_BY): + return self._parse_csv(self._parse_conjunction) + return [] + def _parse_partitioned_by(self) -> exp.Expression: self._match(TokenType.EQ) return self.expression( @@ -1093,21 +1298,6 @@ class Parser(metaclass=_Parser): return self.expression(exp.ReturnsProperty, this=value, is_table=is_table) - def _parse_properties(self) -> t.Optional[exp.Expression]: - properties = [] - - while True: - identified_property = self._parse_property() - if not identified_property: - break - for p in ensure_collection(identified_property): - properties.append(p) - - if properties: - return self.expression(exp.Properties, expressions=properties) - - return None - def _parse_describe(self) -> exp.Expression: kind = self._match_set(self.CREATABLES) and self._prev.text this = self._parse_table() @@ -1248,11 +1438,9 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.PARTITION): return None - def parse_values() -> exp.Property: - props = self._parse_csv(self._parse_var_or_string, sep=TokenType.EQ) - return exp.Property(this=seq_get(props, 0), value=seq_get(props, 1)) - - return self.expression(exp.Partition, this=self._parse_wrapped_csv(parse_values)) + return self.expression( + exp.Partition, expressions=self._parse_wrapped_csv(self._parse_conjunction) + ) def _parse_value(self) -> exp.Expression: if self._match(TokenType.L_PAREN): @@ -1360,8 +1548,7 @@ class Parser(metaclass=_Parser): if not alias or not alias.this: self.raise_error("Expected CTE to have alias") - if not self._match(TokenType.ALIAS): - self.raise_error("Expected AS in CTE") + self._match(TokenType.ALIAS) return self.expression( exp.CTE, @@ -1376,10 +1563,11 @@ class Parser(metaclass=_Parser): alias = self._parse_id_var( any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS ) + index = self._index if self._match(TokenType.L_PAREN): columns = self._parse_csv(lambda: self._parse_column_def(self._parse_id_var())) - self._match_r_paren() + self._match_r_paren() if columns else self._retreat(index) else: columns = None @@ -1452,6 +1640,87 @@ class Parser(metaclass=_Parser): exp.From, comments=self._prev_comments, expressions=self._parse_csv(self._parse_table) ) + def _parse_match_recognize(self) -> t.Optional[exp.Expression]: + if not self._match(TokenType.MATCH_RECOGNIZE): + return None + self._match_l_paren() + + partition = self._parse_partition_by() + order = self._parse_order() + measures = ( + self._parse_alias(self._parse_conjunction()) + if self._match_text_seq("MEASURES") + else 
None + ) + + if self._match_text_seq("ONE", "ROW", "PER", "MATCH"): + rows = exp.Var(this="ONE ROW PER MATCH") + elif self._match_text_seq("ALL", "ROWS", "PER", "MATCH"): + text = "ALL ROWS PER MATCH" + if self._match_text_seq("SHOW", "EMPTY", "MATCHES"): + text += f" SHOW EMPTY MATCHES" + elif self._match_text_seq("OMIT", "EMPTY", "MATCHES"): + text += f" OMIT EMPTY MATCHES" + elif self._match_text_seq("WITH", "UNMATCHED", "ROWS"): + text += f" WITH UNMATCHED ROWS" + rows = exp.Var(this=text) + else: + rows = None + + if self._match_text_seq("AFTER", "MATCH", "SKIP"): + text = "AFTER MATCH SKIP" + if self._match_text_seq("PAST", "LAST", "ROW"): + text += f" PAST LAST ROW" + elif self._match_text_seq("TO", "NEXT", "ROW"): + text += f" TO NEXT ROW" + elif self._match_text_seq("TO", "FIRST"): + text += f" TO FIRST {self._advance_any().text}" # type: ignore + elif self._match_text_seq("TO", "LAST"): + text += f" TO LAST {self._advance_any().text}" # type: ignore + after = exp.Var(this=text) + else: + after = None + + if self._match_text_seq("PATTERN"): + self._match_l_paren() + + if not self._curr: + self.raise_error("Expecting )", self._curr) + + paren = 1 + start = self._curr + + while self._curr and paren > 0: + if self._curr.token_type == TokenType.L_PAREN: + paren += 1 + if self._curr.token_type == TokenType.R_PAREN: + paren -= 1 + self._advance() + if paren > 0: + self.raise_error("Expecting )", self._curr) + if not self._curr: + self.raise_error("Expecting pattern", self._curr) + end = self._prev + pattern = exp.Var(this=self._find_sql(start, end)) + else: + pattern = None + + define = ( + self._parse_alias(self._parse_conjunction()) if self._match_text_seq("DEFINE") else None + ) + self._match_r_paren() + + return self.expression( + exp.MatchRecognize, + partition_by=partition, + order=order, + measures=measures, + rows=rows, + after=after, + pattern=pattern, + define=define, + ) + def _parse_lateral(self) -> t.Optional[exp.Expression]: outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY) cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY) @@ -1772,12 +2041,19 @@ class Parser(metaclass=_Parser): if not skip_group_by_token and not self._match(TokenType.GROUP_BY): return None + expressions = self._parse_csv(self._parse_conjunction) + grouping_sets = self._parse_grouping_sets() + + with_ = self._match(TokenType.WITH) + cube = self._match(TokenType.CUBE) and (with_ or self._parse_wrapped_id_vars()) + rollup = self._match(TokenType.ROLLUP) and (with_ or self._parse_wrapped_id_vars()) + return self.expression( exp.Group, - expressions=self._parse_csv(self._parse_conjunction), - grouping_sets=self._parse_grouping_sets(), - cube=self._match(TokenType.CUBE) and self._parse_wrapped_id_vars(), - rollup=self._match(TokenType.ROLLUP) and self._parse_wrapped_id_vars(), + expressions=expressions, + grouping_sets=grouping_sets, + cube=cube, + rollup=rollup, ) def _parse_grouping_sets(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]: @@ -1788,11 +2064,11 @@ class Parser(metaclass=_Parser): def _parse_grouping_set(self) -> t.Optional[exp.Expression]: if self._match(TokenType.L_PAREN): - grouping_set = self._parse_csv(self._parse_id_var) + grouping_set = self._parse_csv(self._parse_column) self._match_r_paren() return self.expression(exp.Tuple, expressions=grouping_set) - return self._parse_id_var() + return self._parse_column() def _parse_having(self, skip_having_token: bool = False) -> t.Optional[exp.Expression]: if not skip_having_token and not 
self._match(TokenType.HAVING): @@ -2268,7 +2544,6 @@ class Parser(metaclass=_Parser): args = self._parse_csv(self._parse_lambda) if function: - # Clickhouse supports function calls like foo(x, y)(z), so for these we need to also parse the # second parameter list (i.e. "(z)") and the corresponding function will receive both arg lists. if count_params(function) == 2: @@ -2541,9 +2816,10 @@ class Parser(metaclass=_Parser): return self.expression(exp.PrimaryKey, expressions=expressions, options=options) def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: - if not self._match(TokenType.L_BRACKET): + if not self._match_set((TokenType.L_BRACKET, TokenType.L_BRACE)): return this + bracket_kind = self._prev.token_type expressions: t.List[t.Optional[exp.Expression]] if self._match(TokenType.COLON): @@ -2551,14 +2827,19 @@ class Parser(metaclass=_Parser): else: expressions = self._parse_csv(lambda: self._parse_slice(self._parse_conjunction())) - if not this or this.name.upper() == "ARRAY": + # https://duckdb.org/docs/sql/data_types/struct.html#creating-structs + if bracket_kind == TokenType.L_BRACE: + this = self.expression(exp.Struct, expressions=expressions) + elif not this or this.name.upper() == "ARRAY": this = self.expression(exp.Array, expressions=expressions) else: expressions = apply_index_offset(expressions, -self.index_offset) this = self.expression(exp.Bracket, this=this, expressions=expressions) - if not self._match(TokenType.R_BRACKET): + if not self._match(TokenType.R_BRACKET) and bracket_kind == TokenType.L_BRACKET: self.raise_error("Expected ]") + elif not self._match(TokenType.R_BRACE) and bracket_kind == TokenType.L_BRACE: + self.raise_error("Expected }") this.comments = self._prev_comments return self._parse_bracket(this) @@ -2727,7 +3008,7 @@ class Parser(metaclass=_Parser): position = self._prev.text.upper() expression = self._parse_term() - if self._match(TokenType.FROM): + if self._match_set((TokenType.FROM, TokenType.COMMA)): this = self._parse_term() else: this = expression @@ -2792,14 +3073,8 @@ class Parser(metaclass=_Parser): return self.expression(exp.Window, this=this, alias=self._parse_id_var(False)) window_alias = self._parse_id_var(any_token=False, tokens=self.WINDOW_ALIAS_TOKENS) - - partition = None - if self._match(TokenType.PARTITION_BY): - partition = self._parse_csv(self._parse_conjunction) - + partition = self._parse_partition_by() order = self._parse_order() - - spec = None kind = self._match_set((TokenType.ROWS, TokenType.RANGE)) and self._prev.text if kind: @@ -2816,6 +3091,8 @@ class Parser(metaclass=_Parser): end=end["value"], end_side=end["side"], ) + else: + spec = None self._match_r_paren() @@ -3060,6 +3337,12 @@ class Parser(metaclass=_Parser): def _parse_drop_column(self) -> t.Optional[exp.Expression]: return self._match(TokenType.DROP) and self._parse_drop(default_kind="COLUMN") + # https://docs.aws.amazon.com/athena/latest/ug/alter-table-drop-partition.html + def _parse_drop_partition(self, exists: t.Optional[bool] = None) -> exp.Expression: + return self.expression( + exp.DropPartition, expressions=self._parse_csv(self._parse_partition), exists=exists + ) + def _parse_add_constraint(self) -> t.Optional[exp.Expression]: this = None kind = self._prev.token_type @@ -3092,14 +3375,24 @@ class Parser(metaclass=_Parser): actions: t.Optional[exp.Expression | t.List[t.Optional[exp.Expression]]] = None index = self._index - if self._match_text_seq("ADD"): + if self._match(TokenType.DELETE): + actions = 
[self.expression(exp.Delete, where=self._parse_where())] + elif self._match_text_seq("ADD"): if self._match_set(self.ADD_CONSTRAINT_TOKENS): actions = self._parse_csv(self._parse_add_constraint) else: self._retreat(index) actions = self._parse_csv(self._parse_add_column) - elif self._match_text_seq("DROP", advance=False): - actions = self._parse_csv(self._parse_drop_column) + elif self._match_text_seq("DROP"): + partition_exists = self._parse_exists() + + if self._match(TokenType.PARTITION, advance=False): + actions = self._parse_csv( + lambda: self._parse_drop_partition(exists=partition_exists) + ) + else: + self._retreat(index) + actions = self._parse_csv(self._parse_drop_column) elif self._match_text_seq("RENAME", "TO"): actions = self.expression(exp.RenameTable, this=self._parse_table(schema=True)) elif self._match_text_seq("ALTER"): diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 19dd1d6..8bdd338 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -22,6 +22,7 @@ class TokenType(AutoName): DCOLON = auto() SEMICOLON = auto() STAR = auto() + BACKSLASH = auto() SLASH = auto() LT = auto() LTE = auto() @@ -157,18 +158,14 @@ class TokenType(AutoName): DELETE = auto() DESC = auto() DESCRIBE = auto() - DETERMINISTIC = auto() DISTINCT = auto() DISTINCT_FROM = auto() - DISTKEY = auto() DISTRIBUTE_BY = auto() - DISTSTYLE = auto() DIV = auto() DROP = auto() ELSE = auto() ENCODE = auto() END = auto() - ENGINE = auto() ESCAPE = auto() EXCEPT = auto() EXECUTE = auto() @@ -182,10 +179,11 @@ class TokenType(AutoName): FOR = auto() FOREIGN_KEY = auto() FORMAT = auto() + FROM = auto() FULL = auto() FUNCTION = auto() - FROM = auto() GENERATED = auto() + GLOB = auto() GLOBAL = auto() GROUP_BY = auto() GROUPING_SETS = auto() @@ -195,7 +193,6 @@ class TokenType(AutoName): IF = auto() IGNORE_NULLS = auto() ILIKE = auto() - IMMUTABLE = auto() IN = auto() INDEX = auto() INNER = auto() @@ -217,8 +214,8 @@ class TokenType(AutoName): LIMIT = auto() LOAD_DATA = auto() LOCAL = auto() - LOCATION = auto() MAP = auto() + MATCH_RECOGNIZE = auto() MATERIALIZED = auto() MERGE = auto() MOD = auto() @@ -242,7 +239,6 @@ class TokenType(AutoName): OVERWRITE = auto() PARTITION = auto() PARTITION_BY = auto() - PARTITIONED_BY = auto() PERCENT = auto() PIVOT = auto() PLACEHOLDER = auto() @@ -258,7 +254,6 @@ class TokenType(AutoName): REPLACE = auto() RESPECT_NULLS = auto() REFERENCES = auto() - RETURNS = auto() RIGHT = auto() RLIKE = auto() ROLLBACK = auto() @@ -277,10 +272,7 @@ class TokenType(AutoName): SOME = auto() SORTKEY = auto() SORT_BY = auto() - STABLE = auto() - STORED = auto() STRUCT = auto() - TABLE_FORMAT = auto() TABLE_SAMPLE = auto() TEMPORARY = auto() TOP = auto() @@ -414,6 +406,7 @@ class Tokenizer(metaclass=_Tokenizer): "+": TokenType.PLUS, ";": TokenType.SEMICOLON, "/": TokenType.SLASH, + "\\": TokenType.BACKSLASH, "*": TokenType.STAR, "~": TokenType.TILDA, "?": TokenType.PLACEHOLDER, @@ -448,9 +441,11 @@ class Tokenizer(metaclass=_Tokenizer): }, **{ f"{prefix}{key}": TokenType.BLOCK_END - for key in ("}}", "%}", "#}") + for key in ("%}", "#}") for prefix in ("", "+", "-") }, + "+}}": TokenType.BLOCK_END, + "-}}": TokenType.BLOCK_END, "/*+": TokenType.HINT, "==": TokenType.EQ, "::": TokenType.DCOLON, @@ -503,17 +498,13 @@ class Tokenizer(metaclass=_Tokenizer): "DELETE": TokenType.DELETE, "DESC": TokenType.DESC, "DESCRIBE": TokenType.DESCRIBE, - "DETERMINISTIC": TokenType.DETERMINISTIC, "DISTINCT": TokenType.DISTINCT, "DISTINCT FROM": TokenType.DISTINCT_FROM, - "DISTKEY": TokenType.DISTKEY, 
"DISTRIBUTE BY": TokenType.DISTRIBUTE_BY, - "DISTSTYLE": TokenType.DISTSTYLE, "DIV": TokenType.DIV, "DROP": TokenType.DROP, "ELSE": TokenType.ELSE, "END": TokenType.END, - "ENGINE": TokenType.ENGINE, "ESCAPE": TokenType.ESCAPE, "EXCEPT": TokenType.EXCEPT, "EXECUTE": TokenType.EXECUTE, @@ -530,13 +521,13 @@ class Tokenizer(metaclass=_Tokenizer): "FORMAT": TokenType.FORMAT, "FROM": TokenType.FROM, "GENERATED": TokenType.GENERATED, + "GLOB": TokenType.GLOB, "GROUP BY": TokenType.GROUP_BY, "GROUPING SETS": TokenType.GROUPING_SETS, "HAVING": TokenType.HAVING, "IDENTITY": TokenType.IDENTITY, "IF": TokenType.IF, "ILIKE": TokenType.ILIKE, - "IMMUTABLE": TokenType.IMMUTABLE, "IGNORE NULLS": TokenType.IGNORE_NULLS, "IN": TokenType.IN, "INDEX": TokenType.INDEX, @@ -548,7 +539,6 @@ class Tokenizer(metaclass=_Tokenizer): "IS": TokenType.IS, "ISNULL": TokenType.ISNULL, "JOIN": TokenType.JOIN, - "LANGUAGE": TokenType.LANGUAGE, "LATERAL": TokenType.LATERAL, "LAZY": TokenType.LAZY, "LEADING": TokenType.LEADING, @@ -557,7 +547,6 @@ class Tokenizer(metaclass=_Tokenizer): "LIMIT": TokenType.LIMIT, "LOAD DATA": TokenType.LOAD_DATA, "LOCAL": TokenType.LOCAL, - "LOCATION": TokenType.LOCATION, "MATERIALIZED": TokenType.MATERIALIZED, "MERGE": TokenType.MERGE, "NATURAL": TokenType.NATURAL, @@ -582,8 +571,8 @@ class Tokenizer(metaclass=_Tokenizer): "OVERWRITE": TokenType.OVERWRITE, "PARTITION": TokenType.PARTITION, "PARTITION BY": TokenType.PARTITION_BY, - "PARTITIONED BY": TokenType.PARTITIONED_BY, - "PARTITIONED_BY": TokenType.PARTITIONED_BY, + "PARTITIONED BY": TokenType.PARTITION_BY, + "PARTITIONED_BY": TokenType.PARTITION_BY, "PERCENT": TokenType.PERCENT, "PIVOT": TokenType.PIVOT, "PRECEDING": TokenType.PRECEDING, @@ -596,7 +585,6 @@ class Tokenizer(metaclass=_Tokenizer): "REPLACE": TokenType.REPLACE, "RESPECT NULLS": TokenType.RESPECT_NULLS, "REFERENCES": TokenType.REFERENCES, - "RETURNS": TokenType.RETURNS, "RIGHT": TokenType.RIGHT, "RLIKE": TokenType.RLIKE, "ROLLBACK": TokenType.ROLLBACK, @@ -613,11 +601,7 @@ class Tokenizer(metaclass=_Tokenizer): "SOME": TokenType.SOME, "SORTKEY": TokenType.SORTKEY, "SORT BY": TokenType.SORT_BY, - "STABLE": TokenType.STABLE, - "STORED": TokenType.STORED, "TABLE": TokenType.TABLE, - "TABLE_FORMAT": TokenType.TABLE_FORMAT, - "TBLPROPERTIES": TokenType.PROPERTIES, "TABLESAMPLE": TokenType.TABLE_SAMPLE, "TEMP": TokenType.TEMPORARY, "TEMPORARY": TokenType.TEMPORARY, diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 35ff75a..aa7d240 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -27,20 +27,18 @@ def unalias_group(expression: exp.Expression) -> exp.Expression: """ if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select): aliased_selects = { - e.alias: (i, e.this) + e.alias: i for i, e in enumerate(expression.parent.expressions, start=1) if isinstance(e, exp.Alias) } - expression = expression.copy() - - top_level_expression = None - for item, parent, _ in expression.walk(bfs=False): - top_level_expression = item if isinstance(parent, exp.Group) else top_level_expression - if isinstance(item, exp.Column) and not item.table: - alias_index, col_expression = aliased_selects.get(item.name, (None, None)) - if alias_index and top_level_expression != col_expression: - item.replace(exp.Literal.number(alias_index)) + for group_by in expression.expressions: + if ( + isinstance(group_by, exp.Column) + and not group_by.table + and group_by.name in aliased_selects + ): + 
group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name))) return expression @@ -63,22 +61,21 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression: and expression.args["distinct"].args.get("on") and isinstance(expression.args["distinct"].args["on"], exp.Tuple) ): - distinct_cols = [e.copy() for e in expression.args["distinct"].args["on"].expressions] - outer_selects = [e.copy() for e in expression.expressions] - nested = expression.copy() - nested.args["distinct"].pop() + distinct_cols = expression.args["distinct"].args["on"].expressions + expression.args["distinct"].pop() + outer_selects = expression.selects row_number = find_new_name(expression.named_selects, "_row_number") window = exp.Window( this=exp.RowNumber(), partition_by=distinct_cols, ) - order = nested.args.get("order") + order = expression.args.get("order") if order: window.set("order", order.copy()) order.pop() window = exp.alias_(window, row_number) - nested.select(window, copy=False) - return exp.select(*outer_selects).from_(nested.subquery()).where(f'"{row_number}" = 1') + expression.select(window, copy=False) + return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1') return expression @@ -120,7 +117,7 @@ def preprocess( """ def _to_sql(self, expression): - expression = transforms[0](expression) + expression = transforms[0](expression.copy()) for t in transforms[1:]: expression = t(expression) return to_sql(self, expression) -- cgit v1.2.3
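
As a quick sanity check of the parser changes above, here is a minimal sketch
exercising two of them. It is not part of the upstream diff: it assumes this
patched version of sqlglot is importable, the table and column names (t, a, b,
c) are made up for illustration, and the printed SQL is indicative rather than
guaranteed byte-for-byte:

    import sqlglot

    # DuckDB struct literals written with braces should now parse, via the
    # TokenType.L_BRACE branch added to _parse_bracket, and round-trip
    # through the new _struct_sql generator.
    print(sqlglot.transpile("SELECT {'a': 1, 'b': 2} AS s", read="duckdb")[0])

    # GROUPING SETS items are now parsed with _parse_column instead of
    # _parse_id_var, so qualified columns such as t.a are accepted.
    sql = "SELECT t.a, t.b, SUM(t.c) FROM t GROUP BY GROUPING SETS ((t.a, t.b), (t.a))"
    print(sqlglot.parse_one(sql).sql())

The related tokenizer change (no longer mapping a bare "}}" to BLOCK_END while
keeping "+}}" and "-}}") presumably exists so that nested struct literals such
as {'a': {'b': 1}} are not cut short by a Jinja-style block terminator.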