diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-11-19 14:50:35 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-11-19 14:50:35 +0000 |
commit | 2272764864555f26095563937e06a3389d42d789 (patch) | |
tree | 9dc37b7bff42ec0343028e5ecfb0aa147c5d3279 /sqlglot/parser.py | |
parent | Adding upstream version 10.0.1. (diff) | |
download | sqlglot-2272764864555f26095563937e06a3389d42d789.tar.xz sqlglot-2272764864555f26095563937e06a3389d42d789.zip |
Adding upstream version 10.0.8.upstream/10.0.8
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/parser.py')
-rw-r--r-- | sqlglot/parser.py | 403 |
1 files changed, 197 insertions, 206 deletions
diff --git a/sqlglot/parser.py b/sqlglot/parser.py index bbea0e5..5b93510 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -141,26 +141,29 @@ class Parser(metaclass=_Parser): ID_VAR_TOKENS = { TokenType.VAR, - TokenType.ALTER, TokenType.ALWAYS, TokenType.ANTI, TokenType.APPLY, + TokenType.AUTO_INCREMENT, TokenType.BEGIN, TokenType.BOTH, TokenType.BUCKET, TokenType.CACHE, - TokenType.CALL, + TokenType.CASCADE, TokenType.COLLATE, + TokenType.COMMAND, TokenType.COMMIT, TokenType.CONSTRAINT, + TokenType.CURRENT_TIME, TokenType.DEFAULT, TokenType.DELETE, TokenType.DESCRIBE, TokenType.DETERMINISTIC, + TokenType.DISTKEY, + TokenType.DISTSTYLE, TokenType.EXECUTE, TokenType.ENGINE, TokenType.ESCAPE, - TokenType.EXPLAIN, TokenType.FALSE, TokenType.FIRST, TokenType.FOLLOWING, @@ -182,7 +185,6 @@ class Parser(metaclass=_Parser): TokenType.NATURAL, TokenType.NEXT, TokenType.ONLY, - TokenType.OPTIMIZE, TokenType.OPTIONS, TokenType.ORDINALITY, TokenType.PARTITIONED_BY, @@ -199,6 +201,7 @@ class Parser(metaclass=_Parser): TokenType.SEMI, TokenType.SET, TokenType.SHOW, + TokenType.SORTKEY, TokenType.STABLE, TokenType.STORED, TokenType.TABLE, @@ -207,7 +210,6 @@ class Parser(metaclass=_Parser): TokenType.TRANSIENT, TokenType.TOP, TokenType.TRAILING, - TokenType.TRUNCATE, TokenType.TRUE, TokenType.UNBOUNDED, TokenType.UNIQUE, @@ -217,6 +219,7 @@ class Parser(metaclass=_Parser): TokenType.VOLATILE, *SUBQUERY_PREDICATES, *TYPE_TOKENS, + *NO_PAREN_FUNCTIONS, } TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL, TokenType.APPLY} @@ -231,6 +234,7 @@ class Parser(metaclass=_Parser): TokenType.FILTER, TokenType.FIRST, TokenType.FORMAT, + TokenType.IDENTIFIER, TokenType.ISNULL, TokenType.OFFSET, TokenType.PRIMARY_KEY, @@ -242,6 +246,7 @@ class Parser(metaclass=_Parser): TokenType.RIGHT, TokenType.DATE, TokenType.DATETIME, + TokenType.TABLE, TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, *TYPE_TOKENS, @@ -277,6 +282,7 @@ class Parser(metaclass=_Parser): TokenType.DASH: exp.Sub, TokenType.PLUS: exp.Add, TokenType.MOD: exp.Mod, + TokenType.COLLATE: exp.Collate, } FACTOR = { @@ -391,7 +397,10 @@ class Parser(metaclass=_Parser): TokenType.DELETE: lambda self: self._parse_delete(), TokenType.CACHE: lambda self: self._parse_cache(), TokenType.UNCACHE: lambda self: self._parse_uncache(), - TokenType.USE: lambda self: self._parse_use(), + TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()), + TokenType.BEGIN: lambda self: self._parse_transaction(), + TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(), + TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(), } PRIMARY_PARSERS = { @@ -402,7 +411,8 @@ class Parser(metaclass=_Parser): exp.Literal, this=token.text, is_string=False ), TokenType.STAR: lambda self, _: self.expression( - exp.Star, **{"except": self._parse_except(), "replace": self._parse_replace()} + exp.Star, + **{"except": self._parse_except(), "replace": self._parse_replace()}, ), TokenType.NULL: lambda self, _: self.expression(exp.Null), TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True), @@ -446,6 +456,9 @@ class Parser(metaclass=_Parser): TokenType.PARTITIONED_BY: lambda self: self._parse_partitioned_by(), TokenType.SCHEMA_COMMENT: lambda self: self._parse_schema_comment(), TokenType.STORED: lambda self: self._parse_stored(), + TokenType.DISTKEY: lambda self: self._parse_distkey(), + TokenType.DISTSTYLE: lambda self: self._parse_diststyle(), + TokenType.SORTKEY: lambda self: self._parse_sortkey(), TokenType.RETURNS: lambda self: self._parse_returns(), TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty), TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty), @@ -471,7 +484,9 @@ class Parser(metaclass=_Parser): } CONSTRAINT_PARSERS = { - TokenType.CHECK: lambda self: self._parse_check(), + TokenType.CHECK: lambda self: self.expression( + exp.Check, this=self._parse_wrapped(self._parse_conjunction) + ), TokenType.FOREIGN_KEY: lambda self: self._parse_foreign_key(), TokenType.UNIQUE: lambda self: self._parse_unique(), } @@ -521,6 +536,8 @@ class Parser(metaclass=_Parser): TokenType.SCHEMA, } + TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"} + STRICT_CAST = True __slots__ = ( @@ -740,6 +757,7 @@ class Parser(metaclass=_Parser): kind=kind, temporary=temporary, materialized=materialized, + cascade=self._match(TokenType.CASCADE), ) def _parse_exists(self, not_=False): @@ -777,7 +795,11 @@ class Parser(metaclass=_Parser): expression = self._parse_select_or_expression() elif create_token.token_type == TokenType.INDEX: this = self._parse_index() - elif create_token.token_type in (TokenType.TABLE, TokenType.VIEW, TokenType.SCHEMA): + elif create_token.token_type in ( + TokenType.TABLE, + TokenType.VIEW, + TokenType.SCHEMA, + ): this = self._parse_table(schema=True) properties = self._parse_properties() if self._match(TokenType.ALIAS): @@ -834,7 +856,38 @@ class Parser(metaclass=_Parser): return self.expression( exp.FileFormatProperty, this=exp.Literal.string("FORMAT"), - value=exp.Literal.string(self._parse_var().name), + value=exp.Literal.string(self._parse_var_or_string().name), + ) + + def _parse_distkey(self): + self._match_l_paren() + this = exp.Literal.string("DISTKEY") + value = exp.Literal.string(self._parse_var().name) + self._match_r_paren() + return self.expression( + exp.DistKeyProperty, + this=this, + value=value, + ) + + def _parse_sortkey(self): + self._match_l_paren() + this = exp.Literal.string("SORTKEY") + value = exp.Literal.string(self._parse_var().name) + self._match_r_paren() + return self.expression( + exp.SortKeyProperty, + this=this, + value=value, + ) + + def _parse_diststyle(self): + this = exp.Literal.string("DISTSTYLE") + value = exp.Literal.string(self._parse_var().name) + return self.expression( + exp.DistStyleProperty, + this=this, + value=value, ) def _parse_auto_increment(self): @@ -842,7 +895,7 @@ class Parser(metaclass=_Parser): return self.expression( exp.AutoIncrementProperty, this=exp.Literal.string("AUTO_INCREMENT"), - value=self._parse_var() or self._parse_number(), + value=self._parse_number(), ) def _parse_schema_comment(self): @@ -898,13 +951,10 @@ class Parser(metaclass=_Parser): while True: if self._match(TokenType.WITH): - self._match_l_paren() - properties.extend(self._parse_csv(lambda: self._parse_property())) - self._match_r_paren() + properties.extend(self._parse_wrapped_csv(self._parse_property)) elif self._match(TokenType.PROPERTIES): - self._match_l_paren() properties.extend( - self._parse_csv( + self._parse_wrapped_csv( lambda: self.expression( exp.AnonymousProperty, this=self._parse_string(), @@ -912,25 +962,24 @@ class Parser(metaclass=_Parser): ) ) ) - self._match_r_paren() else: identified_property = self._parse_property() if not identified_property: break properties.append(identified_property) + if properties: return self.expression(exp.Properties, expressions=properties) return None def _parse_describe(self): self._match(TokenType.TABLE) - return self.expression(exp.Describe, this=self._parse_id_var()) def _parse_insert(self): overwrite = self._match(TokenType.OVERWRITE) local = self._match(TokenType.LOCAL) - if self._match_text("DIRECTORY"): + if self._match_text_seq("DIRECTORY"): this = self.expression( exp.Directory, this=self._parse_var_or_string(), @@ -954,27 +1003,27 @@ class Parser(metaclass=_Parser): if not self._match_pair(TokenType.ROW, TokenType.FORMAT): return None - self._match_text("DELIMITED") + self._match_text_seq("DELIMITED") kwargs = {} - if self._match_text("FIELDS", "TERMINATED", "BY"): + if self._match_text_seq("FIELDS", "TERMINATED", "BY"): kwargs["fields"] = self._parse_string() - if self._match_text("ESCAPED", "BY"): + if self._match_text_seq("ESCAPED", "BY"): kwargs["escaped"] = self._parse_string() - if self._match_text("COLLECTION", "ITEMS", "TERMINATED", "BY"): + if self._match_text_seq("COLLECTION", "ITEMS", "TERMINATED", "BY"): kwargs["collection_items"] = self._parse_string() - if self._match_text("MAP", "KEYS", "TERMINATED", "BY"): + if self._match_text_seq("MAP", "KEYS", "TERMINATED", "BY"): kwargs["map_keys"] = self._parse_string() - if self._match_text("LINES", "TERMINATED", "BY"): + if self._match_text_seq("LINES", "TERMINATED", "BY"): kwargs["lines"] = self._parse_string() - if self._match_text("NULL", "DEFINED", "AS"): + if self._match_text_seq("NULL", "DEFINED", "AS"): kwargs["null"] = self._parse_string() return self.expression(exp.RowFormat, **kwargs) def _parse_load_data(self): local = self._match(TokenType.LOCAL) - self._match_text("INPATH") + self._match_text_seq("INPATH") inpath = self._parse_string() overwrite = self._match(TokenType.OVERWRITE) self._match_pair(TokenType.INTO, TokenType.TABLE) @@ -986,8 +1035,8 @@ class Parser(metaclass=_Parser): overwrite=overwrite, inpath=inpath, partition=self._parse_partition(), - input_format=self._match_text("INPUTFORMAT") and self._parse_string(), - serde=self._match_text("SERDE") and self._parse_string(), + input_format=self._match_text_seq("INPUTFORMAT") and self._parse_string(), + serde=self._match_text_seq("SERDE") and self._parse_string(), ) def _parse_delete(self): @@ -996,9 +1045,7 @@ class Parser(metaclass=_Parser): return self.expression( exp.Delete, this=self._parse_table(schema=True), - using=self._parse_csv( - lambda: self._match(TokenType.USING) and self._parse_table(schema=True) - ), + using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table()), where=self._parse_where(), ) @@ -1029,12 +1076,7 @@ class Parser(metaclass=_Parser): options = [] if self._match(TokenType.OPTIONS): - self._match_l_paren() - k = self._parse_string() - self._match(TokenType.EQ) - v = self._parse_string() - options = [k, v] - self._match_r_paren() + options = self._parse_wrapped_csv(self._parse_string, sep=TokenType.EQ) self._match(TokenType.ALIAS) return self.expression( @@ -1050,27 +1092,13 @@ class Parser(metaclass=_Parser): return None def parse_values(): - key = self._parse_var() - value = None - - if self._match(TokenType.EQ): - value = self._parse_string() - - return exp.Property(this=key, value=value) - - self._match_l_paren() - values = self._parse_csv(parse_values) - self._match_r_paren() + props = self._parse_csv(self._parse_var_or_string, sep=TokenType.EQ) + return exp.Property(this=seq_get(props, 0), value=seq_get(props, 1)) - return self.expression( - exp.Partition, - this=values, - ) + return self.expression(exp.Partition, this=self._parse_wrapped_csv(parse_values)) def _parse_value(self): - self._match_l_paren() - expressions = self._parse_csv(self._parse_conjunction) - self._match_r_paren() + expressions = self._parse_wrapped_csv(self._parse_conjunction) return self.expression(exp.Tuple, expressions=expressions) def _parse_select(self, nested=False, table=False): @@ -1124,10 +1152,11 @@ class Parser(metaclass=_Parser): self._match_r_paren() this = self._parse_subquery(this) elif self._match(TokenType.VALUES): - this = self.expression(exp.Values, expressions=self._parse_csv(self._parse_value)) - alias = self._parse_table_alias() - if alias: - this = self.expression(exp.Subquery, this=this, alias=alias) + this = self.expression( + exp.Values, + expressions=self._parse_csv(self._parse_value), + alias=self._parse_table_alias(), + ) else: this = None @@ -1140,7 +1169,6 @@ class Parser(metaclass=_Parser): recursive = self._match(TokenType.RECURSIVE) expressions = [] - while True: expressions.append(self._parse_cte()) @@ -1149,11 +1177,7 @@ class Parser(metaclass=_Parser): else: self._match(TokenType.WITH) - return self.expression( - exp.With, - expressions=expressions, - recursive=recursive, - ) + return self.expression(exp.With, expressions=expressions, recursive=recursive) def _parse_cte(self): alias = self._parse_table_alias() @@ -1163,13 +1187,9 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.ALIAS): self.raise_error("Expected AS in CTE") - self._match_l_paren() - expression = self._parse_statement() - self._match_r_paren() - return self.expression( exp.CTE, - this=expression, + this=self._parse_wrapped(self._parse_statement), alias=alias, ) @@ -1223,7 +1243,7 @@ class Parser(metaclass=_Parser): def _parse_hint(self): if self._match(TokenType.HINT): hints = self._parse_csv(self._parse_function) - if not self._match(TokenType.HINT): + if not self._match_pair(TokenType.STAR, TokenType.SLASH): self.raise_error("Expected */ after HINT") return self.expression(exp.Hint, expressions=hints) return None @@ -1259,26 +1279,18 @@ class Parser(metaclass=_Parser): columns = self._parse_csv(self._parse_id_var) elif self._match(TokenType.L_PAREN): columns = self._parse_csv(self._parse_id_var) - self._match(TokenType.R_PAREN) + self._match_r_paren() expression = self.expression( exp.Lateral, this=this, view=view, outer=outer, - alias=self.expression( - exp.TableAlias, - this=table_alias, - columns=columns, - ), + alias=self.expression(exp.TableAlias, this=table_alias, columns=columns), ) if outer_apply or cross_apply: - return self.expression( - exp.Join, - this=expression, - side=None if cross_apply else "LEFT", - ) + return self.expression(exp.Join, this=expression, side=None if cross_apply else "LEFT") return expression @@ -1387,12 +1399,8 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.UNNEST): return None - self._match_l_paren() - expressions = self._parse_csv(self._parse_column) - self._match_r_paren() - + expressions = self._parse_wrapped_csv(self._parse_column) ordinality = bool(self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY)) - alias = self._parse_table_alias() if alias and self.unnest_column_only: @@ -1402,10 +1410,7 @@ class Parser(metaclass=_Parser): alias.set("this", None) return self.expression( - exp.Unnest, - expressions=expressions, - ordinality=ordinality, - alias=alias, + exp.Unnest, expressions=expressions, ordinality=ordinality, alias=alias ) def _parse_derived_table_values(self): @@ -1418,13 +1423,7 @@ class Parser(metaclass=_Parser): if is_derived: self._match_r_paren() - alias = self._parse_table_alias() - - return self.expression( - exp.Values, - expressions=expressions, - alias=alias, - ) + return self.expression(exp.Values, expressions=expressions, alias=self._parse_table_alias()) def _parse_table_sample(self): if not self._match(TokenType.TABLE_SAMPLE): @@ -1460,9 +1459,7 @@ class Parser(metaclass=_Parser): self._match_r_paren() if self._match(TokenType.SEED): - self._match_l_paren() - seed = self._parse_number() - self._match_r_paren() + seed = self._parse_wrapped(self._parse_number) return self.expression( exp.TableSample, @@ -1513,12 +1510,7 @@ class Parser(metaclass=_Parser): self._match_r_paren() - return self.expression( - exp.Pivot, - expressions=expressions, - field=field, - unpivot=unpivot, - ) + return self.expression(exp.Pivot, expressions=expressions, field=field, unpivot=unpivot) def _parse_where(self, skip_where_token=False): if not skip_where_token and not self._match(TokenType.WHERE): @@ -1539,11 +1531,7 @@ class Parser(metaclass=_Parser): def _parse_grouping_sets(self): if not self._match(TokenType.GROUPING_SETS): return None - - self._match_l_paren() - grouping_sets = self._parse_csv(self._parse_grouping_set) - self._match_r_paren() - return grouping_sets + return self._parse_wrapped_csv(self._parse_grouping_set) def _parse_grouping_set(self): if self._match(TokenType.L_PAREN): @@ -1573,7 +1561,6 @@ class Parser(metaclass=_Parser): def _parse_sort(self, token_type, exp_class): if not self._match(token_type): return None - return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered)) def _parse_ordered(self): @@ -1602,9 +1589,12 @@ class Parser(metaclass=_Parser): if self._match(TokenType.TOP if top else TokenType.LIMIT): limit_paren = self._match(TokenType.L_PAREN) limit_exp = self.expression(exp.Limit, this=this, expression=self._parse_number()) + if limit_paren: - self._match(TokenType.R_PAREN) + self._match_r_paren() + return limit_exp + if self._match(TokenType.FETCH): direction = self._match_set((TokenType.FIRST, TokenType.NEXT)) direction = self._prev.text if direction else "FIRST" @@ -1612,11 +1602,13 @@ class Parser(metaclass=_Parser): self._match_set((TokenType.ROW, TokenType.ROWS)) self._match(TokenType.ONLY) return self.expression(exp.Fetch, direction=direction, count=count) + return this def _parse_offset(self, this=None): if not self._match_set((TokenType.OFFSET, TokenType.COMMA)): return this + count = self._parse_number() self._match_set((TokenType.ROW, TokenType.ROWS)) return self.expression(exp.Offset, this=this, expression=count) @@ -1678,6 +1670,7 @@ class Parser(metaclass=_Parser): if self._match(TokenType.DISTINCT_FROM): klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ return self.expression(klass, this=this, expression=self._parse_expression()) + this = self.expression( exp.Is, this=this, @@ -1754,11 +1747,7 @@ class Parser(metaclass=_Parser): def _parse_type(self): if self._match(TokenType.INTERVAL): - return self.expression( - exp.Interval, - this=self._parse_term(), - unit=self._parse_var(), - ) + return self.expression(exp.Interval, this=self._parse_term(), unit=self._parse_var()) index = self._index type_token = self._parse_types(check_func=True) @@ -1824,30 +1813,18 @@ class Parser(metaclass=_Parser): value = None if type_token in self.TIMESTAMPS: if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ: - value = exp.DataType( - this=exp.DataType.Type.TIMESTAMPTZ, - expressions=expressions, - ) + value = exp.DataType(this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions) elif ( self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ ): - value = exp.DataType( - this=exp.DataType.Type.TIMESTAMPLTZ, - expressions=expressions, - ) + value = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions) elif self._match(TokenType.WITHOUT_TIME_ZONE): - value = exp.DataType( - this=exp.DataType.Type.TIMESTAMP, - expressions=expressions, - ) + value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions) maybe_func = maybe_func and value is None if value is None: - value = exp.DataType( - this=exp.DataType.Type.TIMESTAMP, - expressions=expressions, - ) + value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions) if maybe_func and check_func: index2 = self._index @@ -1872,6 +1849,7 @@ class Parser(metaclass=_Parser): this = self._parse_id_var() self._match(TokenType.COLON) data_type = self._parse_types() + if not data_type: return None return self.expression(exp.StructKwarg, this=this, expression=data_type) @@ -1879,7 +1857,6 @@ class Parser(metaclass=_Parser): def _parse_at_time_zone(self, this): if not self._match(TokenType.AT_TIME_ZONE): return this - return self.expression(exp.AtTimeZone, this=this, zone=self._parse_unary()) def _parse_column(self): @@ -1984,16 +1961,14 @@ class Parser(metaclass=_Parser): else: subquery_predicate = self.SUBQUERY_PREDICATES.get(token_type) - if subquery_predicate and self._curr.token_type in ( - TokenType.SELECT, - TokenType.WITH, - ): + if subquery_predicate and self._curr.token_type in (TokenType.SELECT, TokenType.WITH): this = self.expression(subquery_predicate, this=self._parse_select()) self._match_r_paren() return this if functions is None: functions = self.FUNCTIONS + function = functions.get(upper) args = self._parse_csv(self._parse_lambda) @@ -2014,6 +1989,7 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.L_PAREN): return this + expressions = self._parse_csv(self._parse_udf_kwarg) self._match_r_paren() return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions) @@ -2021,25 +1997,19 @@ class Parser(metaclass=_Parser): def _parse_introducer(self, token): literal = self._parse_primary() if literal: - return self.expression( - exp.Introducer, - this=token.text, - expression=literal, - ) + return self.expression(exp.Introducer, this=token.text, expression=literal) return self.expression(exp.Identifier, this=token.text) def _parse_session_parameter(self): kind = None this = self._parse_id_var() or self._parse_primary() + if self._match(TokenType.DOT): kind = this.name this = self._parse_var() or self._parse_primary() - return self.expression( - exp.SessionParameter, - this=this, - kind=kind, - ) + + return self.expression(exp.SessionParameter, this=this, kind=kind) def _parse_udf_kwarg(self): this = self._parse_id_var() @@ -2106,7 +2076,10 @@ class Parser(metaclass=_Parser): return self.expression(exp.ColumnDef, this=this, kind=kind, constraints=constraints) def _parse_column_constraint(self): - this = None + this = self._parse_references() + + if this: + return this if self._match(TokenType.CONSTRAINT): this = self._parse_id_var() @@ -2114,13 +2087,12 @@ class Parser(metaclass=_Parser): if self._match(TokenType.AUTO_INCREMENT): kind = exp.AutoIncrementColumnConstraint() elif self._match(TokenType.CHECK): - self._match_l_paren() - kind = self.expression(exp.CheckColumnConstraint, this=self._parse_conjunction()) - self._match_r_paren() + constraint = self._parse_wrapped(self._parse_conjunction) + kind = self.expression(exp.CheckColumnConstraint, this=constraint) elif self._match(TokenType.COLLATE): kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var()) elif self._match(TokenType.DEFAULT): - kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_field()) + kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_conjunction()) elif self._match_pair(TokenType.NOT, TokenType.NULL): kind = exp.NotNullColumnConstraint() elif self._match(TokenType.SCHEMA_COMMENT): @@ -2137,7 +2109,7 @@ class Parser(metaclass=_Parser): kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True) self._match_pair(TokenType.ALIAS, TokenType.IDENTITY) else: - return None + return this return self.expression(exp.ColumnConstraint, this=this, kind=kind) @@ -2159,37 +2131,29 @@ class Parser(metaclass=_Parser): def _parse_unnamed_constraint(self): if not self._match_set(self.CONSTRAINT_PARSERS): return None - return self.CONSTRAINT_PARSERS[self._prev.token_type](self) - def _parse_check(self): - self._match(TokenType.CHECK) - self._match_l_paren() - expression = self._parse_conjunction() - self._match_r_paren() - - return self.expression(exp.Check, this=expression) - def _parse_unique(self): - self._match(TokenType.UNIQUE) - columns = self._parse_wrapped_id_vars() - - return self.expression(exp.Unique, expressions=columns) + return self.expression(exp.Unique, expressions=self._parse_wrapped_id_vars()) - def _parse_foreign_key(self): - self._match(TokenType.FOREIGN_KEY) - - expressions = self._parse_wrapped_id_vars() - reference = self._match(TokenType.REFERENCES) and self.expression( + def _parse_references(self): + if not self._match(TokenType.REFERENCES): + return None + return self.expression( exp.Reference, this=self._parse_id_var(), expressions=self._parse_wrapped_id_vars(), ) + + def _parse_foreign_key(self): + expressions = self._parse_wrapped_id_vars() + reference = self._parse_references() options = {} while self._match(TokenType.ON): if not self._match_set((TokenType.DELETE, TokenType.UPDATE)): self.raise_error("Expected DELETE or UPDATE") + kind = self._prev.text.lower() if self._match(TokenType.NO_ACTION): @@ -2200,6 +2164,7 @@ class Parser(metaclass=_Parser): else: self._advance() action = self._prev.text.upper() + options[kind] = action return self.expression( @@ -2363,20 +2328,14 @@ class Parser(metaclass=_Parser): def _parse_window(self, this, alias=False): if self._match(TokenType.FILTER): - self._match_l_paren() - this = self.expression(exp.Filter, this=this, expression=self._parse_where()) - self._match_r_paren() + where = self._parse_wrapped(self._parse_where) + this = self.expression(exp.Filter, this=this, expression=where) # T-SQL allows the OVER (...) syntax after WITHIN GROUP. # https://learn.microsoft.com/en-us/sql/t-sql/functions/percentile-disc-transact-sql?view=sql-server-ver16 if self._match(TokenType.WITHIN_GROUP): - self._match_l_paren() - this = self.expression( - exp.WithinGroup, - this=this, - expression=self._parse_order(), - ) - self._match_r_paren() + order = self._parse_wrapped(self._parse_order) + this = self.expression(exp.WithinGroup, this=this, expression=order) # SQL spec defines an optional [ { IGNORE | RESPECT } NULLS ] OVER # Some dialects choose to implement and some do not. @@ -2404,18 +2363,11 @@ class Parser(metaclass=_Parser): return this if not self._match(TokenType.L_PAREN): - alias = self._parse_id_var(False) - - return self.expression( - exp.Window, - this=this, - alias=alias, - ) - - partition = None + return self.expression(exp.Window, this=this, alias=self._parse_id_var(False)) alias = self._parse_id_var(False) + partition = None if self._match(TokenType.PARTITION_BY): partition = self._parse_csv(self._parse_conjunction) @@ -2552,17 +2504,13 @@ class Parser(metaclass=_Parser): def _parse_replace(self): if not self._match(TokenType.REPLACE): return None + return self._parse_wrapped_csv(lambda: self._parse_alias(self._parse_expression())) - self._match_l_paren() - columns = self._parse_csv(lambda: self._parse_alias(self._parse_expression())) - self._match_r_paren() - return columns - - def _parse_csv(self, parse_method): + def _parse_csv(self, parse_method, sep=TokenType.COMMA): parse_result = parse_method() items = [parse_result] if parse_result is not None else [] - while self._match(TokenType.COMMA): + while self._match(sep): if parse_result and self._prev_comment is not None: parse_result.comment = self._prev_comment @@ -2583,16 +2531,53 @@ class Parser(metaclass=_Parser): return this def _parse_wrapped_id_vars(self): + return self._parse_wrapped_csv(self._parse_id_var) + + def _parse_wrapped_csv(self, parse_method, sep=TokenType.COMMA): + return self._parse_wrapped(lambda: self._parse_csv(parse_method, sep=sep)) + + def _parse_wrapped(self, parse_method): self._match_l_paren() - expressions = self._parse_csv(self._parse_id_var) + parse_result = parse_method() self._match_r_paren() - return expressions + return parse_result def _parse_select_or_expression(self): return self._parse_select() or self._parse_expression() - def _parse_use(self): - return self.expression(exp.Use, this=self._parse_id_var()) + def _parse_transaction(self): + this = None + if self._match_texts(self.TRANSACTION_KIND): + this = self._prev.text + + self._match_texts({"TRANSACTION", "WORK"}) + + modes = [] + while True: + mode = [] + while self._match(TokenType.VAR): + mode.append(self._prev.text) + + if mode: + modes.append(" ".join(mode)) + if not self._match(TokenType.COMMA): + break + + return self.expression(exp.Transaction, this=this, modes=modes) + + def _parse_commit_or_rollback(self): + savepoint = None + is_rollback = self._prev.token_type == TokenType.ROLLBACK + + self._match_texts({"TRANSACTION", "WORK"}) + + if self._match_text_seq("TO"): + self._match_text_seq("SAVEPOINT") + savepoint = self._parse_id_var() + + if is_rollback: + return self.expression(exp.Rollback, savepoint=savepoint) + return self.expression(exp.Commit) def _parse_show(self): parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) @@ -2675,7 +2660,13 @@ class Parser(metaclass=_Parser): if expression and self._prev_comment: expression.comment = self._prev_comment - def _match_text(self, *texts): + def _match_texts(self, texts): + if self._curr and self._curr.text.upper() in texts: + self._advance() + return True + return False + + def _match_text_seq(self, *texts): index = self._index for text in texts: if self._curr and self._curr.text.upper() == text: |