From 5b1ac5070c43c40a2b5bbc991198b0dddf45dc75 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 8 Mar 2023 08:22:15 +0100 Subject: Merging upstream version 11.3.3. Signed-off-by: Daniel Baumann --- sqlglot/__init__.py | 2 +- sqlglot/dialects/bigquery.py | 2 + sqlglot/dialects/dialect.py | 5 ++ sqlglot/dialects/duckdb.py | 5 +- sqlglot/dialects/hive.py | 2 + sqlglot/dialects/mysql.py | 4 +- sqlglot/dialects/postgres.py | 3 + sqlglot/dialects/redshift.py | 1 + sqlglot/dialects/snowflake.py | 22 +++++-- sqlglot/dialects/sqlite.py | 14 +++-- sqlglot/dialects/teradata.py | 7 ++- sqlglot/dialects/tsql.py | 8 ++- sqlglot/expressions.py | 32 ++++++++-- sqlglot/generator.py | 33 +++++++--- sqlglot/optimizer/annotate_types.py | 9 ++- sqlglot/parser.py | 118 ++++++++++++++++++++++++++---------- sqlglot/tokens.py | 8 ++- 17 files changed, 209 insertions(+), 66 deletions(-) (limited to 'sqlglot') diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index d026627..a9a220c 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -47,7 +47,7 @@ if t.TYPE_CHECKING: T = t.TypeVar("T", bound=Expression) -__version__ = "11.3.0" +__version__ = "11.3.3" pretty = False """Whether to format generated SQL by default.""" diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index a3869c6..ccdd1c9 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import ( Dialect, datestrtodate_sql, inline_array_sql, + min_or_least, no_ilike_sql, rename_func, timestrtotime_sql, @@ -232,6 +233,7 @@ class BigQuery(Dialect): exp.GroupConcat: rename_func("STRING_AGG"), exp.ILike: no_ilike_sql, exp.IntDiv: rename_func("DIV"), + exp.Min: min_or_least, exp.Select: transforms.preprocess( [_unqualify_unnest], transforms.delegate("select_sql") ), diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 6939705..25490cb 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -407,6 +407,11 @@ def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: return f"CAST({self.sql(expression, 'this')} AS DATE)" +def min_or_least(self: Generator, expression: exp.Min) -> str: + name = "LEAST" if expression.expressions else "MIN" + return rename_func(name)(self, expression) + + def trim_sql(self: Generator, expression: exp.Trim) -> str: target = self.sql(expression, "this") trim_type = self.sql(expression, "position") diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index c2755cd..db79d86 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -11,7 +11,6 @@ from sqlglot.dialects.dialect import ( no_pivot_sql, no_properties_sql, no_safe_divide_sql, - no_tablesample_sql, rename_func, str_position_sql, str_to_time_sql, @@ -155,7 +154,6 @@ class DuckDB(Dialect): exp.StrToTime: str_to_time_sql, exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))", exp.Struct: _struct_sql, - exp.TableSample: no_tablesample_sql, exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", exp.TimeStrToTime: timestrtotime_sql, exp.TimeStrToUnix: lambda self, e: f"EPOCH(CAST({self.sql(e, 'this')} AS TIMESTAMP))", @@ -179,3 +177,6 @@ class DuckDB(Dialect): **generator.Generator.STAR_MAPPING, "except": "EXCLUDE", } + + def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str: + return super().tablesample_sql(expression, seed_prefix="REPEATABLE") diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 44cd875..faed1cf 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import ( format_time_lambda, if_sql, locate_to_strposition, + min_or_least, no_ilike_sql, no_recursive_cte_sql, no_safe_divide_sql, @@ -291,6 +292,7 @@ class Hive(Dialect): exp.JSONExtract: rename_func("GET_JSON_OBJECT"), exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"), exp.Map: var_map_sql, + exp.Min: min_or_least, exp.VarMap: var_map_sql, exp.Create: create_with_partitions_sql, exp.Quantile: rename_func("PERCENTILE"), diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index b1e20bd..3531f59 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -4,6 +4,7 @@ from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, locate_to_strposition, + min_or_least, no_ilike_sql, no_paren_current_date_sql, no_tablesample_sql, @@ -179,7 +180,7 @@ class MySQL(Dialect): COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW} class Parser(parser.Parser): - FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA} # type: ignore + FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA, TokenType.DATABASE} # type: ignore FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore @@ -441,6 +442,7 @@ class MySQL(Dialect): exp.CurrentDate: no_paren_current_date_sql, exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.ILike: no_ilike_sql, + exp.Min: min_or_least, exp.TableSample: no_tablesample_sql, exp.TryCast: no_trycast_sql, exp.DateAdd: _date_add_sql("ADD"), diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 35076db..7e7902c 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import ( arrow_json_extract_scalar_sql, arrow_json_extract_sql, format_time_lambda, + min_or_least, no_paren_current_date_sql, no_tablesample_sql, no_trycast_sql, @@ -229,6 +230,7 @@ class Postgres(Dialect): "REFRESH": TokenType.COMMAND, "REINDEX": TokenType.COMMAND, "RESET": TokenType.COMMAND, + "RETURNING": TokenType.RETURNING, "REVOKE": TokenType.COMMAND, "SERIAL": TokenType.SERIAL, "SMALLSERIAL": TokenType.SMALLSERIAL, @@ -296,6 +298,7 @@ class Postgres(Dialect): exp.DateSub: _date_add_sql("-"), exp.DateDiff: _date_diff_sql, exp.LogicalOr: rename_func("BOOL_OR"), + exp.Min: min_or_least, exp.RegexpLike: lambda self, e: self.binary(e, "~"), exp.RegexpILike: lambda self, e: self.binary(e, "~*"), exp.StrPosition: str_position_sql, diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index b4268e6..22ef51c 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -53,6 +53,7 @@ class Redshift(Postgres): "SUPER": TokenType.SUPER, "TIME": TokenType.TIMESTAMP, "TIMETZ": TokenType.TIMESTAMPTZ, + "TOP": TokenType.TOP, "UNLOAD": TokenType.COMMAND, "VARBYTE": TokenType.VARBINARY, } diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 4a090c2..6413f6d 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import ( datestrtodate_sql, format_time_lambda, inline_array_sql, + min_or_least, rename_func, timestrtotime_sql, ts_or_ds_to_date_sql, @@ -116,10 +117,16 @@ def _div0_to_if(args): # https://docs.snowflake.com/en/sql-reference/functions/zeroifnull def _zeroifnull_to_if(args): - cond = exp.EQ(this=seq_get(args, 0), expression=exp.Null()) + cond = exp.Is(this=seq_get(args, 0), expression=exp.Null()) return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0)) +# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull +def _nullifzero_to_if(args): + cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0)) + return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0)) + + def _datatype_sql(self, expression): if expression.this == exp.DataType.Type.ARRAY: return "ARRAY" @@ -167,6 +174,11 @@ class Snowflake(Dialect): **parser.Parser.FUNCTIONS, "ARRAYAGG": exp.ArrayAgg.from_arg_list, "ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list, + "DATEADD": lambda args: exp.DateAdd( + this=seq_get(args, 2), + expression=seq_get(args, 1), + unit=seq_get(args, 0), + ), "DATE_TRUNC": lambda args: exp.DateTrunc( unit=exp.Literal.string(seq_get(args, 0).name), # type: ignore this=seq_get(args, 1), @@ -180,6 +192,7 @@ class Snowflake(Dialect): "DECODE": exp.Matches.from_arg_list, "OBJECT_CONSTRUCT": parser.parse_var_map, "ZEROIFNULL": _zeroifnull_to_if, + "NULLIFZERO": _nullifzero_to_if, } FUNCTION_PARSERS = { @@ -254,6 +267,7 @@ class Snowflake(Dialect): class Generator(generator.Generator): PARAMETER_TOKEN = "$" INTEGER_DIVISION = False + MATCHED_BY_SOURCE = False TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore @@ -278,6 +292,7 @@ class Snowflake(Dialect): exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"), exp.UnixToTime: _unix_to_time_sql, exp.DayOfWeek: rename_func("DAYOFWEEK"), + exp.Min: min_or_least, } TYPE_MAPPING = { @@ -343,11 +358,10 @@ class Snowflake(Dialect): expression. This might not be true in a case where the same column name can be sourced from another table that can properly quote but should be true in most cases. """ - values_expressions = expression.find_all(exp.Values) values_identifiers = set( flatten( - v.args.get("alias", exp.Alias()).args.get("columns", []) - for v in values_expressions + (v.args.get("alias") or exp.Alias()).args.get("columns", []) + for v in expression.find_all(exp.Values) ) ) if values_identifiers: diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 86603b5..fb99d49 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -13,10 +13,6 @@ from sqlglot.dialects.dialect import ( from sqlglot.tokens import TokenType -def _fetch_sql(self, expression): - return self.limit_sql(exp.Limit(expression=expression.args.get("count"))) - - # https://www.sqlite.org/lang_aggfunc.html#group_concat def _group_concat_sql(self, expression): this = expression.this @@ -94,9 +90,17 @@ class SQLite(Dialect): exp.TimeStrToTime: lambda self, e: self.sql(e, "this"), exp.TryCast: no_trycast_sql, exp.GroupConcat: _group_concat_sql, - exp.Fetch: _fetch_sql, } + def fetch_sql(self, expression): + return self.limit_sql(exp.Limit(expression=expression.args.get("count"))) + + def least_sql(self, expression): + if len(expression.expressions) > 1: + return rename_func("MIN")(self, expression) + + return self.expressions(expression) + def transaction_sql(self, expression): this = expression.this this = f" {this}" if this else "" diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 415681c..06d5c6c 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -1,7 +1,7 @@ from __future__ import annotations from sqlglot import exp, generator, parser, tokens -from sqlglot.dialects.dialect import Dialect +from sqlglot.dialects.dialect import Dialect, min_or_least from sqlglot.tokens import TokenType @@ -126,6 +126,11 @@ class Teradata(Dialect): exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX, } + TRANSFORMS = { + **generator.Generator.TRANSFORMS, + exp.Min: min_or_least, + } + def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> str: return f"PARTITION BY {self.sql(expression, 'this')}" diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index b9f932b..d07b083 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -4,7 +4,12 @@ import re import typing as t from sqlglot import exp, generator, parser, tokens -from sqlglot.dialects.dialect import Dialect, parse_date_delta, rename_func +from sqlglot.dialects.dialect import ( + Dialect, + min_or_least, + parse_date_delta, + rename_func, +) from sqlglot.expressions import DataType from sqlglot.helper import seq_get from sqlglot.time import format_time @@ -433,6 +438,7 @@ class TSQL(Dialect): exp.NumberToStr: _format_sql, exp.TimeToStr: _format_sql, exp.GroupConcat: _string_agg_sql, + exp.Min: min_or_least, } TRANSFORMS.pop(exp.ReturnsProperty) diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 00a3b45..085871e 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1031,7 +1031,7 @@ class Constraint(Expression): class Delete(Expression): - arg_types = {"with": False, "this": False, "using": False, "where": False} + arg_types = {"with": False, "this": False, "using": False, "where": False, "returning": False} class Drop(Expression): @@ -1132,6 +1132,7 @@ class Insert(Expression): "with": False, "this": True, "expression": False, + "returning": False, "overwrite": False, "exists": False, "partition": False, @@ -1139,6 +1140,10 @@ class Insert(Expression): } +class Returning(Expression): + arg_types = {"expressions": True} + + # https://dev.mysql.com/doc/refman/8.0/en/charset-introducer.html class Introducer(Expression): arg_types = {"this": True, "expression": True} @@ -1747,6 +1752,7 @@ QUERY_MODIFIERS = { "limit": False, "offset": False, "lock": False, + "sample": False, } @@ -1895,6 +1901,7 @@ class Update(Expression): "expressions": True, "from": False, "where": False, + "returning": False, } @@ -2401,6 +2408,18 @@ class Select(Subqueryable): **opts, ) + def qualify(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: + return _apply_conjunction_builder( + *expressions, + instance=self, + arg="qualify", + append=append, + into=Qualify, + dialect=dialect, + copy=copy, + **opts, + ) + def distinct(self, distinct=True, copy=True) -> Select: """ Set the OFFSET expression. @@ -2531,6 +2550,7 @@ class TableSample(Expression): "rows": False, "size": False, "seed": False, + "kind": False, } @@ -3423,7 +3443,7 @@ class JSONBExtractScalar(JSONExtract): class Least(Func): - arg_types = {"this": True, "expressions": False} + arg_types = {"expressions": False} is_var_len_args = True @@ -3485,11 +3505,13 @@ class Matches(Func): class Max(AggFunc): - arg_types = {"this": True, "expression": False} + arg_types = {"this": True, "expressions": False} + is_var_len_args = True class Min(AggFunc): - arg_types = {"this": True, "expression": False} + arg_types = {"this": True, "expressions": False} + is_var_len_args = True class Month(Func): @@ -3764,7 +3786,7 @@ class Merge(Expression): class When(Func): - arg_types = {"this": True, "then": True} + arg_types = {"matched": True, "source": False, "condition": False, "then": True} def _norm_args(expression): diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 79501ef..4504e95 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -112,6 +112,9 @@ class Generator: # Whether or not to treat the division operator "/" as integer division INTEGER_DIVISION = True + # Whether or not MERGE ... WHEN MATCHED BY SOURCE is allowed + MATCHED_BY_SOURCE = True + TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", exp.DataType.Type.NVARCHAR: "VARCHAR", @@ -688,7 +691,8 @@ class Generator: else "" ) where_sql = self.sql(expression, "where") - sql = f"DELETE{this}{using_sql}{where_sql}" + returning = self.sql(expression, "returning") + sql = f"DELETE{this}{using_sql}{where_sql}{returning}" return self.prepend_ctes(expression, sql) def drop_sql(self, expression: exp.Drop) -> str: @@ -952,8 +956,9 @@ class Generator: self.sql(expression, "partition") if expression.args.get("partition") else "" ) expression_sql = self.sql(expression, "expression") + returning = self.sql(expression, "returning") sep = self.sep() if partition_sql else "" - sql = f"INSERT{alternative}{this}{exists}{partition_sql}{sep}{expression_sql}" + sql = f"INSERT{alternative}{this}{exists}{partition_sql}{sep}{expression_sql}{returning}" return self.prepend_ctes(expression, sql) def intersect_sql(self, expression: exp.Intersect) -> str: @@ -971,6 +976,9 @@ class Generator: def pseudotype_sql(self, expression: exp.PseudoType) -> str: return expression.name.upper() + def returning_sql(self, expression: exp.Returning) -> str: + return f"{self.seg('RETURNING')} {self.expressions(expression, flat=True)}" + def rowformatdelimitedproperty_sql(self, expression: exp.RowFormatDelimitedProperty) -> str: fields = expression.args.get("fields") fields = f" FIELDS TERMINATED BY {fields}" if fields else "" @@ -1009,7 +1017,7 @@ class Generator: return f"{table}{system_time}{alias}{hints}{laterals}{joins}{pivots}" - def tablesample_sql(self, expression: exp.TableSample) -> str: + def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str: if self.alias_post_tablesample and expression.this.alias: this = self.sql(expression.this, "this") alias = f" AS {self.sql(expression.this, 'alias')}" @@ -1017,7 +1025,7 @@ class Generator: this = self.sql(expression, "this") alias = "" method = self.sql(expression, "method") - method = f" {method.upper()} " if method else "" + method = f"{method.upper()} " if method else "" numerator = self.sql(expression, "bucket_numerator") denominator = self.sql(expression, "bucket_denominator") field = self.sql(expression, "bucket_field") @@ -1029,8 +1037,9 @@ class Generator: rows = f"{rows} ROWS" if rows else "" size = self.sql(expression, "size") seed = self.sql(expression, "seed") - seed = f" SEED ({seed})" if seed else "" - return f"{this} TABLESAMPLE{method}({bucket}{percent}{rows}{size}){seed}{alias}" + seed = f" {seed_prefix} ({seed})" if seed else "" + kind = expression.args.get("kind", "TABLESAMPLE") + return f"{this} {kind} {method}({bucket}{percent}{rows}{size}){seed}{alias}" def pivot_sql(self, expression: exp.Pivot) -> str: this = self.sql(expression, "this") @@ -1050,7 +1059,8 @@ class Generator: set_sql = self.expressions(expression, flat=True) from_sql = self.sql(expression, "from") where_sql = self.sql(expression, "where") - sql = f"UPDATE {this} SET {set_sql}{from_sql}{where_sql}" + returning = self.sql(expression, "returning") + sql = f"UPDATE {this} SET {set_sql}{from_sql}{where_sql}{returning}" return self.prepend_ctes(expression, sql) def values_sql(self, expression: exp.Values) -> str: @@ -1297,6 +1307,7 @@ class Generator: self.sql(expression, "limit"), self.sql(expression, "offset"), self.sql(expression, "lock"), + self.sql(expression, "sample"), sep="", ) @@ -1956,7 +1967,11 @@ class Generator: return self.binary(expression, "=>") def when_sql(self, expression: exp.When) -> str: - this = self.sql(expression, "this") + matched = "MATCHED" if expression.args["matched"] else "NOT MATCHED" + source = " BY SOURCE" if self.MATCHED_BY_SOURCE and expression.args.get("source") else "" + condition = self.sql(expression, "condition") + condition = f" AND {condition}" if condition else "" + then_expression = expression.args.get("then") if isinstance(then_expression, exp.Insert): then = f"INSERT {self.sql(then_expression, 'this')}" @@ -1969,7 +1984,7 @@ class Generator: then = f"UPDATE SET {self.expressions(then_expression, flat=True)}" else: then = self.sql(then_expression) - return f"WHEN {this} THEN {then}" + return f"WHEN {matched}{source}{condition} THEN {then}" def merge_sql(self, expression: exp.Merge) -> str: this = self.sql(expression, "this") diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index ca2131c..c2d6655 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -55,9 +55,11 @@ class TypeAnnotator: expr, exp.DataType.Type.BIGINT ), exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.Min: lambda self, expr: self._annotate_by_args(expr, "this"), - exp.Max: lambda self, expr: self._annotate_by_args(expr, "this"), - exp.Sum: lambda self, expr: self._annotate_by_args(expr, "this", promote=True), + exp.Min: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"), + exp.Max: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"), + exp.Sum: lambda self, expr: self._annotate_by_args( + expr, "this", "expressions", promote=True + ), exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), @@ -114,6 +116,7 @@ class TypeAnnotator: expr, exp.DataType.Type.VARCHAR ), exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"), exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), diff --git a/sqlglot/parser.py b/sqlglot/parser.py index f39bb39..894d68e 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -434,6 +434,7 @@ class Parser(metaclass=_Parser): exp.Having: lambda self: self._parse_having(), exp.With: lambda self: self._parse_with(), exp.Window: lambda self: self._parse_named_window(), + exp.Qualify: lambda self: self._parse_qualify(), "JOIN_TYPE": lambda self: self._parse_join_side_and_kind(), } @@ -688,6 +689,7 @@ class Parser(metaclass=_Parser): "limit": lambda self: self._parse_limit(), "offset": lambda self: self._parse_offset(), "lock": lambda self: self._parse_lock(), + "sample": lambda self: self._parse_table_sample(as_modifier=True), } SHOW_PARSERS: t.Dict[str, t.Callable] = {} @@ -953,7 +955,8 @@ class Parser(metaclass=_Parser): self._prev_comments = None def _retreat(self, index: int) -> None: - self._advance(index - self._index) + if index != self._index: + self._advance(index - self._index) def _parse_command(self) -> exp.Expression: return self.expression(exp.Command, this=self._prev.text, expression=self._parse_string()) @@ -1515,12 +1518,10 @@ class Parser(metaclass=_Parser): def _parse_insert(self) -> exp.Expression: overwrite = self._match(TokenType.OVERWRITE) local = self._match(TokenType.LOCAL) - - this: t.Optional[exp.Expression] - alternative = None + if self._match_text_seq("DIRECTORY"): - this = self.expression( + this: t.Optional[exp.Expression] = self.expression( exp.Directory, this=self._parse_var_or_string(), local=local, @@ -1540,10 +1541,17 @@ class Parser(metaclass=_Parser): exists=self._parse_exists(), partition=self._parse_partition(), expression=self._parse_ddl_select(), + returning=self._parse_returning(), overwrite=overwrite, alternative=alternative, ) + def _parse_returning(self) -> t.Optional[exp.Expression]: + if not self._match(TokenType.RETURNING): + return None + + return self.expression(exp.Returning, expressions=self._parse_csv(self._parse_column)) + def _parse_row(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.FORMAT): return None @@ -1601,6 +1609,7 @@ class Parser(metaclass=_Parser): this=self._parse_table(schema=True), using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table()), where=self._parse_where(), + returning=self._parse_returning(), ) def _parse_update(self) -> exp.Expression: @@ -1611,6 +1620,7 @@ class Parser(metaclass=_Parser): "expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality), "from": self._parse_from(), "where": self._parse_where(), + "returning": self._parse_returning(), }, ) @@ -2156,11 +2166,12 @@ class Parser(metaclass=_Parser): return self.expression(exp.Values, expressions=expressions, alias=self._parse_table_alias()) - def _parse_table_sample(self) -> t.Optional[exp.Expression]: - if not self._match(TokenType.TABLE_SAMPLE): + def _parse_table_sample(self, as_modifier: bool = False) -> t.Optional[exp.Expression]: + if not self._match(TokenType.TABLE_SAMPLE) and not ( + as_modifier and self._match_text_seq("USING", "SAMPLE") + ): return None - method = self._parse_var() bucket_numerator = None bucket_denominator = None bucket_field = None @@ -2169,7 +2180,12 @@ class Parser(metaclass=_Parser): size = None seed = None - self._match_l_paren() + kind = "TABLESAMPLE" if self._prev.token_type == TokenType.TABLE_SAMPLE else "USING SAMPLE" + method = self._parse_var(tokens=(TokenType.ROW,)) + + self._match(TokenType.L_PAREN) + + num = self._parse_number() if self._match(TokenType.BUCKET): bucket_numerator = self._parse_number() @@ -2177,19 +2193,20 @@ class Parser(metaclass=_Parser): bucket_denominator = bucket_denominator = self._parse_number() self._match(TokenType.ON) bucket_field = self._parse_field() + elif self._match_set((TokenType.PERCENT, TokenType.MOD)): + percent = num + elif self._match(TokenType.ROWS): + rows = num else: - num = self._parse_number() + size = num - if self._match(TokenType.PERCENT): - percent = num - elif self._match(TokenType.ROWS): - rows = num - else: - size = num + self._match(TokenType.R_PAREN) - self._match_r_paren() - - if self._match(TokenType.SEED): + if self._match(TokenType.L_PAREN): + method = self._parse_var() + seed = self._match(TokenType.COMMA) and self._parse_number() + self._match_r_paren() + elif self._match_texts(("SEED", "REPEATABLE")): seed = self._parse_wrapped(self._parse_number) return self.expression( @@ -2202,6 +2219,7 @@ class Parser(metaclass=_Parser): rows=rows, size=size, seed=seed, + kind=kind, ) def _parse_pivots(self) -> t.List[t.Optional[exp.Expression]]: @@ -2531,7 +2549,7 @@ class Parser(metaclass=_Parser): this = self._parse_column() if type_token: - if this and not isinstance(this, exp.Star): + if isinstance(this, exp.Literal): return self.expression(exp.Cast, this=this, to=type_token) if not type_token.args.get("expressions"): self._retreat(index) @@ -2626,7 +2644,12 @@ class Parser(metaclass=_Parser): if value is None: value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions) elif type_token == TokenType.INTERVAL: - value = self.expression(exp.Interval, unit=self._parse_var()) + unit = self._parse_var() + + if not unit: + value = self.expression(exp.DataType, this=exp.DataType.Type.INTERVAL) + else: + value = self.expression(exp.Interval, unit=unit) if maybe_func and check_func: index2 = self._index @@ -3495,8 +3518,14 @@ class Parser(metaclass=_Parser): return self.expression(exp.Identifier, this=self._prev.text, quoted=True) return self._parse_placeholder() - def _parse_var(self, any_token: bool = False) -> t.Optional[exp.Expression]: - if (any_token and self._advance_any()) or self._match(TokenType.VAR): + def _parse_var( + self, any_token: bool = False, tokens: t.Optional[t.Collection[TokenType]] = None + ) -> t.Optional[exp.Expression]: + if ( + (any_token and self._advance_any()) + or self._match(TokenType.VAR) + or (self._match_set(tokens) if tokens else False) + ): return self.expression(exp.Var, this=self._prev.text) return self._parse_placeholder() @@ -3732,19 +3761,26 @@ class Parser(metaclass=_Parser): return self.expression(exp.RenameTable, this=self._parse_table(schema=True)) def _parse_alter(self) -> t.Optional[exp.Expression]: + start = self._prev + if not self._match(TokenType.TABLE): - return self._parse_as_command(self._prev) + return self._parse_as_command(start) exists = self._parse_exists() this = self._parse_table(schema=True) - if not self._curr: - return None - - parser = self.ALTER_PARSERS.get(self._curr.text.upper()) - actions = ensure_list(self._advance() or parser(self)) if parser else [] # type: ignore + if self._next: + self._advance() + parser = self.ALTER_PARSERS.get(self._prev.text.upper()) if self._prev else None - return self.expression(exp.AlterTable, this=this, exists=exists, actions=actions) + if parser: + return self.expression( + exp.AlterTable, + this=this, + exists=exists, + actions=ensure_list(parser(self)), + ) + return self._parse_as_command(start) def _parse_show(self) -> t.Optional[exp.Expression]: parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) # type: ignore @@ -3775,7 +3811,15 @@ class Parser(metaclass=_Parser): whens = [] while self._match(TokenType.WHEN): - this = self._parse_conjunction() + matched = not self._match(TokenType.NOT) + self._match_text_seq("MATCHED") + source = ( + False + if self._match_text_seq("BY", "TARGET") + else self._match_text_seq("BY", "SOURCE") + ) + condition = self._parse_conjunction() if self._match(TokenType.AND) else None + self._match(TokenType.THEN) if self._match(TokenType.INSERT): @@ -3800,8 +3844,18 @@ class Parser(metaclass=_Parser): ) elif self._match(TokenType.DELETE): then = self.expression(exp.Var, this=self._prev.text) + else: + then = None - whens.append(self.expression(exp.When, this=this, then=then)) + whens.append( + self.expression( + exp.When, + matched=matched, + source=source, + condition=condition, + then=then, + ) + ) return self.expression( exp.Merge, diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 7a23803..5f4b77d 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -855,11 +855,12 @@ class Tokenizer(metaclass=_Tokenizer): def _scan_keywords(self) -> None: size = 0 word = None - chars: t.Optional[str] = self._text + chars = self._text char = chars prev_space = False skip = False trie = self.KEYWORD_TRIE + single_token = char in self.SINGLE_TOKENS while chars: if skip: @@ -876,6 +877,7 @@ class Tokenizer(metaclass=_Tokenizer): if end < self.size: char = self.sql[end] + single_token = single_token or char in self.SINGLE_TOKENS is_space = char in self.WHITE_SPACE if not is_space or not prev_space: @@ -887,7 +889,9 @@ class Tokenizer(metaclass=_Tokenizer): else: skip = True else: - chars = None + chars = " " + + word = None if not single_token and chars[-1] not in self.WHITE_SPACE else word if not word: if self._char in self.SINGLE_TOKENS: -- cgit v1.2.3