Diffstat (limited to 'sqlglot')
-rw-r--r-- | sqlglot/dialects/bigquery.py               | 12
-rw-r--r-- | sqlglot/dialects/clickhouse.py             | 10
-rw-r--r-- | sqlglot/dialects/dialect.py                | 23
-rw-r--r-- | sqlglot/dialects/drill.py                  |  6
-rw-r--r-- | sqlglot/dialects/duckdb.py                 | 23
-rw-r--r-- | sqlglot/dialects/hive.py                   | 14
-rw-r--r-- | sqlglot/dialects/mysql.py                  | 10
-rw-r--r-- | sqlglot/dialects/postgres.py               |  4
-rw-r--r-- | sqlglot/dialects/presto.py                 | 26
-rw-r--r-- | sqlglot/dialects/redshift.py               | 23
-rw-r--r-- | sqlglot/dialects/snowflake.py              |  5
-rw-r--r-- | sqlglot/dialects/starrocks.py              | 11
-rw-r--r-- | sqlglot/dialects/teradata.py               |  9
-rw-r--r-- | sqlglot/dialects/tsql.py                   | 20
-rw-r--r-- | sqlglot/expressions.py                     | 26
-rw-r--r-- | sqlglot/generator.py                       | 45
-rw-r--r-- | sqlglot/optimizer/normalize_identifiers.py |  3
-rw-r--r-- | sqlglot/optimizer/pushdown_projections.py  | 19
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py       | 84
-rw-r--r-- | sqlglot/parser.py                          | 58
-rw-r--r-- | sqlglot/tokens.py                          |  6
21 files changed, 256 insertions, 181 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index df9065f..71977dd 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -34,7 +34,7 @@ def _date_add_sql(
         this = self.sql(expression, "this")
         unit = expression.args.get("unit")
         unit = exp.var(unit.name.upper() if unit else "DAY")
-        interval = exp.Interval(this=expression.expression, unit=unit)
+        interval = exp.Interval(this=expression.expression.copy(), unit=unit)
         return f"{data_type}_{kind}({this}, {self.sql(interval)})"

     return func
@@ -76,16 +76,12 @@ def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsPrope
 def _create_sql(self: generator.Generator, expression: exp.Create) -> str:
     kind = expression.args["kind"]
     returns = expression.find(exp.ReturnsProperty)
+
     if kind.upper() == "FUNCTION" and returns and returns.args.get("is_table"):
         expression = expression.copy()
         expression.set("kind", "TABLE FUNCTION")
-        if isinstance(
-            expression.expression,
-            (
-                exp.Subquery,
-                exp.Literal,
-            ),
-        ):
+
+        if isinstance(expression.expression, (exp.Subquery, exp.Literal)):
             expression.set("expression", expression.expression.this)

     return self.create_sql(expression)

diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index ce1a486..e6b7743 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -77,7 +77,7 @@ class ClickHouse(Dialect):
         FUNCTION_PARSERS.pop("MATCH")

         NO_PAREN_FUNCTION_PARSERS = parser.Parser.NO_PAREN_FUNCTION_PARSERS.copy()
-        NO_PAREN_FUNCTION_PARSERS.pop(TokenType.ANY)
+        NO_PAREN_FUNCTION_PARSERS.pop("ANY")

         RANGE_PARSERS = {
             **parser.Parser.RANGE_PARSERS,
@@ -355,6 +355,7 @@ class ClickHouse(Dialect):
         def safeconcat_sql(self, expression: exp.SafeConcat) -> str:
             # Clickhouse errors out if we try to cast a NULL value to TEXT
+            expression = expression.copy()
             return self.func(
                 "CONCAT",
                 *[
@@ -389,11 +390,7 @@ class ClickHouse(Dialect):
         def oncluster_sql(self, expression: exp.OnCluster) -> str:
             return f"ON CLUSTER {self.sql(expression, 'this')}"

-        def createable_sql(
-            self,
-            expression: exp.Create,
-            locations: dict[exp.Properties.Location, list[exp.Property]],
-        ) -> str:
+        def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str:
             kind = self.sql(expression, "kind").upper()
             if kind in self.ON_CLUSTER_TARGETS and locations.get(exp.Properties.Location.POST_NAME):
                 this_name = self.sql(expression.this, "this")
                 this_properties = self.properties(
@@ -402,4 +399,5 @@
                 )
                 this_schema = self.schema_columns_sql(expression.this)
                 return f"{this_name}{self.sep()}{this_properties}{self.sep()}{this_schema}"
+
         return super().createable_sql(expression, locations)

diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 05e81ce..1d0584c 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -346,7 +346,9 @@ def inline_array_sql(self: Generator, expression: exp.Array) -> str:
 def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
     return self.like_sql(
-        exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
+        exp.Like(
+            this=exp.Lower(this=expression.this.copy()), expression=expression.expression.copy()
+        )
     )
@@ -410,7 +412,7 @@ def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
 def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
     this = self.sql(expression, "this")
-    struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True))
+    struct_key = self.sql(exp.Identifier(this=expression.expression.copy(), quoted=True))
     return f"{this}.{struct_key}"
@@ -571,6 +573,17 @@ def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
     return self.sql(exp.cast(expression.this, "date"))


+# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
+def encode_decode_sql(
+    self: Generator, expression: exp.Expression, name: str, replace: bool = True
+) -> str:
+    charset = expression.args.get("charset")
+    if charset and charset.name.lower() != "utf-8":
+        self.unsupported(f"Expected utf-8 character set, got {charset}.")
+
+    return self.func(name, expression.this, expression.args.get("replace") if replace else None)
+
+
 def min_or_least(self: Generator, expression: exp.Min) -> str:
     name = "LEAST" if expression.expressions else "MIN"
     return rename_func(name)(self, expression)
@@ -588,7 +601,7 @@ def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
         cond = expression.this.expressions[0]
         self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

-    return self.func("sum", exp.func("if", cond, 1, 0))
+    return self.func("sum", exp.func("if", cond.copy(), 1, 0))


 def trim_sql(self: Generator, expression: exp.Trim) -> str:
@@ -625,6 +638,7 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable:

 def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str:
+    expression = expression.copy()
     this, *rest_args = expression.expressions
     for arg in rest_args:
         this = exp.DPipe(this=this, expression=arg)
@@ -674,11 +688,10 @@ def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectTyp
     return names


-def simplify_literal(expression: E, copy: bool = True) -> E:
+def simplify_literal(expression: E) -> E:
     if not isinstance(expression.expression, exp.Literal):
         from sqlglot.optimizer.simplify import simplify

-        expression = exp.maybe_copy(expression, copy)
         simplify(expression.expression)

     return expression
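The new encode_decode_sql helper replaces Presto's _ensure_utf8, which raised UnsupportedError unconditionally; a non-UTF-8 charset is now reported through self.unsupported, so severity follows the generator's configured error level. A minimal sketch of the intended behavior, assuming the public transpile API (the expected outputs are illustrative, not taken from this commit):

    import sqlglot

    # DECODE(x, 'utf-8') should map onto Presto's FROM_UTF8, dropping the
    # charset argument that the target function does not accept.
    print(sqlglot.transpile("SELECT DECODE(x, 'utf-8')", write="presto"))
    # expected: ["SELECT FROM_UTF8(x)"]

    # A non-utf-8 charset is flagged via self.unsupported(...) instead of
    # raising outright; what happens next depends on unsupported_level.
    print(sqlglot.transpile("SELECT DECODE(x, 'latin-1')", write="presto"))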
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index 26d09ce..1b2681d 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -20,9 +20,7 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
     def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
         this = self.sql(expression, "this")
         unit = exp.var(expression.text("unit").upper() or "DAY")
-        return (
-            f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
-        )
+        return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})"

     return func
@@ -145,7 +143,7 @@ class Drill(Dialect):
             exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
             exp.ToChar: lambda self, e: self.function_fallback_sql(e),
             exp.TryCast: no_trycast_sql,
-            exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})",
+            exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression.copy(), unit=exp.var('DAY')))})",
             exp.TsOrDsToDate: ts_or_ds_to_date_sql("drill"),
             exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
         }

diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 219b1aa..5428e86 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -11,6 +11,7 @@ from sqlglot.dialects.dialect import (
     binary_from_function,
     date_trunc_to_time,
     datestrtodate_sql,
+    encode_decode_sql,
     format_time_lambda,
     no_comment_column_constraint_sql,
     no_properties_sql,
@@ -32,14 +33,14 @@ from sqlglot.tokens import TokenType
 def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str:
     this = self.sql(expression, "this")
     unit = self.sql(expression, "unit").strip("'") or "DAY"
-    return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
+    return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))}"


 def _date_delta_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
     this = self.sql(expression, "this")
     unit = self.sql(expression, "unit").strip("'") or "DAY"
     op = "+" if isinstance(expression, exp.DateAdd) else "-"
-    return f"{this} {op} {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
+    return f"{this} {op} {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))}"


 # BigQuery -> DuckDB conversion for the DATE function
@@ -167,6 +168,16 @@ class DuckDB(Dialect):
             "XOR": binary_from_function(exp.BitwiseXor),
         }

+        FUNCTION_PARSERS = {
+            **parser.Parser.FUNCTION_PARSERS,
+            "DECODE": lambda self: self.expression(
+                exp.Decode, this=self._parse_conjunction(), charset=exp.Literal.string("utf-8")
+            ),
+            "ENCODE": lambda self: self.expression(
+                exp.Encode, this=self._parse_conjunction(), charset=exp.Literal.string("utf-8")
+            ),
+        }
+
         TYPE_TOKENS = {
             *parser.Parser.TYPE_TOKENS,
             TokenType.UBIGINT,
@@ -215,7 +226,9 @@ class DuckDB(Dialect):
             ),
             exp.DateStrToDate: datestrtodate_sql,
             exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)",
+            exp.Decode: lambda self, e: encode_decode_sql(self, e, "DECODE", replace=False),
             exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.DATEINT_FORMAT}) AS DATE)",
+            exp.Encode: lambda self, e: encode_decode_sql(self, e, "ENCODE", replace=False),
             exp.Explode: rename_func("UNNEST"),
             exp.IntDiv: lambda self, e: self.binary(e, "//"),
             exp.JSONExtract: arrow_json_extract_sql,
@@ -228,8 +241,8 @@ class DuckDB(Dialect):
             exp.MonthsBetween: lambda self, e: self.func(
                 "DATEDIFF",
                 "'month'",
-                exp.cast(e.expression, "timestamp"),
-                exp.cast(e.this, "timestamp"),
+                exp.cast(e.expression, "timestamp", copy=True),
+                exp.cast(e.this, "timestamp", copy=True),
             ),
             exp.Properties: no_properties_sql,
             exp.RegexpExtract: regexp_extract_sql,
@@ -290,7 +303,7 @@ class DuckDB(Dialect):
                 multiplier = 90

             if multiplier:
-                return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('day')))})"
+                return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this.copy(), unit=exp.var('day')))})"

             return super().interval_sql(expression)
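Because DuckDB's ENCODE/DECODE take no charset argument, the FUNCTION_PARSERS entries above pin charset to utf-8 at parse time, which lets the expressions transpile to dialects that encode the charset in the function name. A hedged example (the output shown is what the mapping implies, not verified against this exact revision):

    import sqlglot

    # DuckDB's one-argument ENCODE parses into exp.Encode with an implicit
    # utf-8 charset, so Presto can render it as TO_UTF8.
    print(sqlglot.transpile("SELECT ENCODE(s)", read="duckdb", write="presto"))
    # expected: ["SELECT TO_UTF8(s)"]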
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 4e84085..aa4d845 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -59,7 +59,7 @@ def _add_date_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateS
     if expression.expression.is_number:
         modified_increment = exp.Literal.number(int(expression.text("expression")) * multiplier)
     else:
-        modified_increment = expression.expression
+        modified_increment = expression.expression.copy()
         if multiplier != 1:
             modified_increment = exp.Mul(  # type: ignore
                 this=modified_increment, expression=exp.Literal.number(multiplier)
@@ -272,8 +272,8 @@ class Hive(Dialect):
             "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
         }

-        FUNCTION_PARSERS = {
-            **parser.Parser.FUNCTION_PARSERS,
+        NO_PAREN_FUNCTION_PARSERS = {
+            **parser.Parser.NO_PAREN_FUNCTION_PARSERS,
             "TRANSFORM": lambda self: self._parse_transform(),
         }
@@ -284,10 +284,12 @@ class Hive(Dialect):
             ),
         }

-        def _parse_transform(self) -> exp.Transform | exp.QueryTransform:
-            args = self._parse_csv(self._parse_lambda)
-            self._match_r_paren()
+        def _parse_transform(self) -> t.Optional[exp.Transform | exp.QueryTransform]:
+            if not self._match(TokenType.L_PAREN, advance=False):
+                self._retreat(self._index - 1)
+                return None

+            args = self._parse_wrapped_csv(self._parse_lambda)
             row_format_before = self._parse_row_format(match_row=True)
             record_writer = None

diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index a54f076..3cd99e7 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -87,9 +87,7 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e
     def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
         this = self.sql(expression, "this")
         unit = expression.text("unit").upper() or "DAY"
-        return (
-            f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
-        )
+        return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})"

     return func
@@ -522,7 +520,7 @@ class MySQL(Dialect):
             exp.StrToTime: _str_to_date_sql,
             exp.TableSample: no_tablesample_sql,
             exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
-            exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime")),
+            exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)),
             exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)),
             exp.Trim: _trim_sql,
             exp.TryCast: no_trycast_sql,
@@ -556,12 +554,12 @@ class MySQL(Dialect):
         def limit_sql(self, expression: exp.Limit, top: bool = False) -> str:
             # MySQL requires simple literal values for its LIMIT clause.
-            expression = simplify_literal(expression)
+            expression = simplify_literal(expression.copy())
             return super().limit_sql(expression, top=top)

         def offset_sql(self, expression: exp.Offset) -> str:
             # MySQL requires simple literal values for its OFFSET clause.
-            expression = simplify_literal(expression)
+            expression = simplify_literal(expression.copy())
             return super().offset_sql(expression)

         def xor_sql(self, expression: exp.Xor) -> str:
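With limit_sql and offset_sql now calling simplify_literal on a copy, generating MySQL SQL simplifies non-literal LIMIT/OFFSET values without mutating the caller's tree. A small sketch of that effect, building the LIMIT programmatically to avoid assumptions about which dialects accept arithmetic in a LIMIT clause:

    import sqlglot
    from sqlglot import exp

    # Attach LIMIT 1 + 1 to a parsed SELECT.
    ast = sqlglot.parse_one("SELECT x FROM t")
    one = exp.Literal.number(1)
    ast.set("limit", exp.Limit(expression=exp.Add(this=one, expression=one.copy())))

    print(ast.sql(dialect="mysql"))  # expected: SELECT x FROM t LIMIT 2
    print(ast.sql())                 # the original tree keeps LIMIT 1 + 1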
- expression = simplify_literal(expression) + expression = simplify_literal(expression.copy()) return super().offset_sql(expression) def xor_sql(self, expression: exp.Xor) -> str: diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index ef100b1..ca44b70 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -40,10 +40,12 @@ DATE_DIFF_FACTOR = { def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]: def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: + expression = expression.copy() + this = self.sql(expression, "this") unit = expression.args.get("unit") - expression = simplify_literal(expression.copy(), copy=False).expression + expression = simplify_literal(expression).expression if not isinstance(expression, exp.Literal): self.unsupported("Cannot add non literal") diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 14ec3dd..291b478 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -7,6 +7,7 @@ from sqlglot.dialects.dialect import ( Dialect, binary_from_function, date_trunc_to_time, + encode_decode_sql, format_time_lambda, if_sql, left_to_substring_sql, @@ -21,7 +22,6 @@ from sqlglot.dialects.dialect import ( timestrtotime_sql, ) from sqlglot.dialects.mysql import MySQL -from sqlglot.errors import UnsupportedError from sqlglot.helper import apply_index_offset, seq_get from sqlglot.tokens import TokenType @@ -41,6 +41,7 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str: if isinstance(expression.this, (exp.Explode, exp.Posexplode)): + expression = expression.copy() return self.sql( exp.Join( this=exp.Unnest( @@ -59,16 +60,6 @@ def _initcap_sql(self: generator.Generator, expression: exp.Initcap) -> str: return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))" -def _decode_sql(self: generator.Generator, expression: exp.Decode) -> str: - _ensure_utf8(expression.args["charset"]) - return self.func("FROM_UTF8", expression.this, expression.args.get("replace")) - - -def _encode_sql(self: generator.Generator, expression: exp.Encode) -> str: - _ensure_utf8(expression.args["charset"]) - return f"TO_UTF8({self.sql(expression, 'this')})" - - def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str: if expression.args.get("asc") == exp.false(): comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END" @@ -106,14 +97,14 @@ def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDat time_format = self.format_time(expression) if time_format and time_format not in (Presto.TIME_FORMAT, Presto.DATE_FORMAT): return exp.cast(_str_to_time_sql(self, expression), "DATE").sql(dialect="presto") - return exp.cast(exp.cast(expression.this, "TIMESTAMP"), "DATE").sql(dialect="presto") + return exp.cast(exp.cast(expression.this, "TIMESTAMP", copy=True), "DATE").sql(dialect="presto") def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str: this = expression.this if not isinstance(this, exp.CurrentDate): - this = exp.cast(exp.cast(expression.this, "TIMESTAMP"), "DATE") + this = exp.cast(exp.cast(expression.this, "TIMESTAMP", copy=True), "DATE") return self.func( "DATE_ADD", @@ -123,11 +114,6 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s ) -def _ensure_utf8(charset: 
exp.Literal) -> None: - if charset.name.lower() != "utf-8": - raise UnsupportedError(f"Unsupported charset {charset}") - - def _approx_percentile(args: t.List) -> exp.Expression: if len(args) == 4: return exp.ApproxQuantile( @@ -288,9 +274,9 @@ class Presto(Dialect): ), exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.DATE_FORMAT}) AS DATE)", exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)", - exp.Decode: _decode_sql, + exp.Decode: lambda self, e: encode_decode_sql(self, e, "FROM_UTF8"), exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.DATEINT_FORMAT}) AS DATE)", - exp.Encode: _encode_sql, + exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"), exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'", exp.Group: transforms.preprocess([transforms.unalias_group]), exp.Hex: rename_func("TO_HEX"), diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index f687ba7..cdb8d0d 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -3,7 +3,11 @@ from __future__ import annotations import typing as t from sqlglot import exp, transforms -from sqlglot.dialects.dialect import concat_to_dpipe_sql, rename_func +from sqlglot.dialects.dialect import ( + concat_to_dpipe_sql, + rename_func, + ts_or_ds_to_date_sql, +) from sqlglot.dialects.postgres import Postgres from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -13,6 +17,14 @@ def _json_sql(self: Postgres.Generator, expression: exp.JSONExtract | exp.JSONEx return f'{self.sql(expression, "this")}."{expression.expression.name}"' +def _parse_date_add(args: t.List) -> exp.DateAdd: + return exp.DateAdd( + this=exp.TsOrDsToDate(this=seq_get(args, 2)), + expression=seq_get(args, 1), + unit=seq_get(args, 0), + ) + + class Redshift(Postgres): # https://docs.aws.amazon.com/redshift/latest/dg/r_names.html RESOLVES_IDENTIFIERS_AS_UPPERCASE = None @@ -32,11 +44,8 @@ class Redshift(Postgres): expression=seq_get(args, 1), unit=exp.var("month"), ), - "DATEADD": lambda args: exp.DateAdd( - this=exp.TsOrDsToDate(this=seq_get(args, 2)), - expression=seq_get(args, 1), - unit=seq_get(args, 0), - ), + "DATEADD": _parse_date_add, + "DATE_ADD": _parse_date_add, "DATEDIFF": lambda args: exp.DateDiff( this=exp.TsOrDsToDate(this=seq_get(args, 2)), expression=exp.TsOrDsToDate(this=seq_get(args, 1)), @@ -123,7 +132,7 @@ class Redshift(Postgres): exp.SafeConcat: concat_to_dpipe_sql, exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", - exp.TsOrDsToDate: lambda self, e: self.sql(e.this), + exp.TsOrDsToDate: ts_or_ds_to_date_sql("redshift"), } # Postgres maps exp.Pivot to no_pivot_sql, but Redshift support pivots diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 499e085..9733a85 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -297,9 +297,10 @@ class Snowflake(Dialect): return super()._parse_id_var(any_token=any_token, tokens=tokens) class Tokenizer(tokens.Tokenizer): - QUOTES = ["'", "$$"] + QUOTES = ["'"] STRING_ESCAPES = ["\\", "'"] HEX_STRINGS = [("x'", "'"), ("X'", "'")] + RAW_STRINGS = ["$$"] COMMENTS = ["--", "//", ("/*", "*/")] KEYWORDS = { @@ -363,6 +364,7 @@ class Snowflake(Dialect): exp.PartitionedByProperty: lambda self, e: f"PARTITION BY 
{self.sql(e, 'this')}", exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), exp.StarMap: rename_func("OBJECT_CONSTRUCT"), + exp.StartsWith: rename_func("STARTSWITH"), exp.StrPosition: lambda self, e: self.func( "POSITION", e.args.get("substr"), e.this, e.args.get("position") ), @@ -382,6 +384,7 @@ class Snowflake(Dialect): exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"), exp.UnixToTime: _unix_to_time_sql, exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), + exp.WeekOfYear: rename_func("WEEKOFYEAR"), } TYPE_MAPPING = { diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py index baa62e8..4f6183c 100644 --- a/sqlglot/dialects/starrocks.py +++ b/sqlglot/dialects/starrocks.py @@ -17,6 +17,13 @@ class StarRocks(MySQL): "DATE_TRUNC": lambda args: exp.TimestampTrunc( this=seq_get(args, 1), unit=seq_get(args, 0) ), + "DATEDIFF": lambda args: exp.DateDiff( + this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY") + ), + "DATE_DIFF": lambda args: exp.DateDiff( + this=seq_get(args, 1), expression=seq_get(args, 2), unit=seq_get(args, 0) + ), + "REGEXP": exp.RegexpLike.from_arg_list, } class Generator(MySQL.Generator): @@ -32,9 +39,11 @@ class StarRocks(MySQL): TRANSFORMS = { **MySQL.Generator.TRANSFORMS, exp.ApproxDistinct: approx_count_distinct_sql, + exp.DateDiff: lambda self, e: self.func( + "DATE_DIFF", exp.Literal.string(e.text("unit") or "DAY"), e.this, e.expression + ), exp.JSONExtractScalar: arrow_json_extract_sql, exp.JSONExtract: arrow_json_extract_sql, - exp.DateDiff: rename_func("DATEDIFF"), exp.RegexpLike: rename_func("REGEXP"), exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimestampTrunc: lambda self, e: self.func( diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 3fac4f5..2be1a62 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least from sqlglot.tokens import TokenType @@ -194,11 +196,7 @@ class Teradata(Dialect): return f"RANGE_N({this} BETWEEN {expressions_sql}{each_sql})" - def createable_sql( - self, - expression: exp.Create, - locations: dict[exp.Properties.Location, list[exp.Property]], - ) -> str: + def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str: kind = self.sql(expression, "kind").upper() if kind == "TABLE" and locations.get(exp.Properties.Location.POST_NAME): this_name = self.sql(expression.this, "this") @@ -209,4 +207,5 @@ class Teradata(Dialect): ) this_schema = self.schema_columns_sql(expression.this) return f"{this_name}{this_properties}{self.sep()}{this_schema}" + return super().createable_sql(expression, locations) diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 0eb0906..131307f 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -395,6 +395,20 @@ class TSQL(Dialect): CONCAT_NULL_OUTPUTS_STRING = True + def _parse_projections(self) -> t.List[t.Optional[exp.Expression]]: + """ + T-SQL supports the syntax alias = expression in the SELECT's projection list, + so we transform all parsed Selects to convert their EQ projections into Aliases. 
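StarRocks now parses both DATEDIFF(a, b) and DATE_DIFF(unit, a, b) into exp.DateDiff and always generates the canonical three-argument DATE_DIFF form. Roughly (expected output derived from the mappings above, not verified against this exact revision):

    import sqlglot

    # The two-argument form gets an implicit DAY unit on the way in, and
    # comes back out in the DATE_DIFF(unit, a, b) shape.
    print(sqlglot.transpile("SELECT DATEDIFF(d1, d2)", read="starrocks", write="starrocks"))
    # expected: ["SELECT DATE_DIFF('DAY', d1, d2)"]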
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 0eb0906..131307f 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -395,6 +395,20 @@ class TSQL(Dialect):
         CONCAT_NULL_OUTPUTS_STRING = True

+        def _parse_projections(self) -> t.List[t.Optional[exp.Expression]]:
+            """
+            T-SQL supports the syntax alias = expression in the SELECT's projection list,
+            so we transform all parsed Selects to convert their EQ projections into Aliases.
+
+            See: https://learn.microsoft.com/en-us/sql/t-sql/queries/select-clause-transact-sql?view=sql-server-ver16#syntax
+            """
+            return [
+                exp.alias_(projection.expression, projection.this.this, copy=False)
+                if isinstance(projection, exp.EQ) and isinstance(projection.this, exp.Column)
+                else projection
+                for projection in super()._parse_projections()
+            ]
+
         def _parse_commit_or_rollback(self) -> exp.Commit | exp.Rollback:
             """Applies to SQL Server and Azure SQL Database
             COMMIT [ { TRAN | TRANSACTION }
@@ -625,11 +639,7 @@ class TSQL(Dialect):
         LIMIT_FETCH = "FETCH"

-        def createable_sql(
-            self,
-            expression: exp.Create,
-            locations: dict[exp.Properties.Location, list[exp.Property]],
-        ) -> str:
+        def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str:
             sql = self.sql(expression, "this")
             properties = expression.args.get("properties")
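The T-SQL _parse_projections override normalizes "alias = expression" projections into ordinary aliases, so downstream code only ever sees one shape. For instance (output per this change):

    import sqlglot

    # The EQ projection "x = 1" is rewritten into an alias at parse time.
    print(sqlglot.parse_one("SELECT x = 1 FROM t", read="tsql").sql())
    # expected: SELECT 1 AS x FROM t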
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index f8e9fee..c207751 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -192,6 +192,13 @@ class Expression(metaclass=_Expression):
         return self.text("alias")

     @property
+    def alias_column_names(self) -> t.List[str]:
+        table_alias = self.args.get("alias")
+        if not table_alias:
+            return []
+        return [c.name for c in table_alias.args.get("columns") or []]
+
+    @property
     def name(self) -> str:
         return self.text("this")
@@ -884,13 +891,6 @@ class Predicate(Condition):

 class DerivedTable(Expression):
     @property
-    def alias_column_names(self) -> t.List[str]:
-        table_alias = self.args.get("alias")
-        if not table_alias:
-            return []
-        return [c.name for c in table_alias.args.get("columns") or []]
-
-    @property
     def selects(self) -> t.List[Expression]:
         return self.this.selects if isinstance(self.this, Subqueryable) else []
@@ -4860,8 +4860,18 @@ def maybe_parse(
     return sqlglot.parse_one(sql, read=dialect, into=into, **opts)

+@t.overload
+def maybe_copy(instance: None, copy: bool = True) -> None:
+    ...
+
+
+@t.overload
 def maybe_copy(instance: E, copy: bool = True) -> E:
-    return instance.copy() if copy else instance
+    ...
+
+
+def maybe_copy(instance, copy=True):
+    return instance.copy() if copy and instance else instance

 def _is_wrong_expression(expression, into):

diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index ed0a681..95db795 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -2,6 +2,7 @@ from __future__ import annotations

 import logging
 import typing as t
+from collections import defaultdict

 from sqlglot import exp
 from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
@@ -676,15 +677,13 @@ class Generator:
         this = f" {this}" if this else ""
         return f"UNIQUE{this}"

-    def createable_sql(
-        self, expression: exp.Create, locations: dict[exp.Properties.Location, list[exp.Property]]
-    ) -> str:
+    def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str:
         return self.sql(expression, "this")

     def create_sql(self, expression: exp.Create) -> str:
         kind = self.sql(expression, "kind").upper()
         properties = expression.args.get("properties")
-        properties_locs = self.locate_properties(properties) if properties else {}
+        properties_locs = self.locate_properties(properties) if properties else defaultdict()

         this = self.createable_sql(expression, properties_locs)
@@ -970,9 +969,9 @@ class Generator:
         for p in expression.expressions:
             p_loc = self.PROPERTIES_LOCATION[p.__class__]
             if p_loc == exp.Properties.Location.POST_WITH:
-                with_properties.append(p)
+                with_properties.append(p.copy())
             elif p_loc == exp.Properties.Location.POST_SCHEMA:
-                root_properties.append(p)
+                root_properties.append(p.copy())

         return self.root_properties(
             exp.Properties(expressions=root_properties)
@@ -1001,30 +1000,13 @@ class Generator:
     def with_properties(self, properties: exp.Properties) -> str:
         return self.properties(properties, prefix=self.seg("WITH"))

-    def locate_properties(
-        self, properties: exp.Properties
-    ) -> t.Dict[exp.Properties.Location, list[exp.Property]]:
-        properties_locs: t.Dict[exp.Properties.Location, list[exp.Property]] = {
-            key: [] for key in exp.Properties.Location
-        }
-
+    def locate_properties(self, properties: exp.Properties) -> t.DefaultDict:
+        properties_locs = defaultdict(list)
         for p in properties.expressions:
             p_loc = self.PROPERTIES_LOCATION[p.__class__]
-            if p_loc == exp.Properties.Location.POST_NAME:
-                properties_locs[exp.Properties.Location.POST_NAME].append(p)
-            elif p_loc == exp.Properties.Location.POST_INDEX:
-                properties_locs[exp.Properties.Location.POST_INDEX].append(p)
-            elif p_loc == exp.Properties.Location.POST_SCHEMA:
-                properties_locs[exp.Properties.Location.POST_SCHEMA].append(p)
-            elif p_loc == exp.Properties.Location.POST_WITH:
-                properties_locs[exp.Properties.Location.POST_WITH].append(p)
-            elif p_loc == exp.Properties.Location.POST_CREATE:
-                properties_locs[exp.Properties.Location.POST_CREATE].append(p)
-            elif p_loc == exp.Properties.Location.POST_ALIAS:
-                properties_locs[exp.Properties.Location.POST_ALIAS].append(p)
-            elif p_loc == exp.Properties.Location.POST_EXPRESSION:
-                properties_locs[exp.Properties.Location.POST_EXPRESSION].append(p)
-            elif p_loc == exp.Properties.Location.UNSUPPORTED:
+            if p_loc != exp.Properties.Location.UNSUPPORTED:
+                properties_locs[p_loc].append(p.copy())
+            else:
                 self.unsupported(f"Unsupported property {p.key}")

         return properties_locs
@@ -1646,9 +1628,9 @@ class Generator:
             with_offset_limit_modifiers = not isinstance(limit, exp.Limit) or not self.LIMIT_IS_TOP

             if self.LIMIT_FETCH == "LIMIT" and isinstance(limit, exp.Fetch):
-                limit = exp.Limit(expression=limit.args.get("count"))
+                limit = exp.Limit(expression=exp.maybe_copy(limit.args.get("count")))
             elif self.LIMIT_FETCH == "FETCH" and isinstance(limit, exp.Limit):
-                limit = exp.Fetch(direction="FIRST", count=limit.expression)
+                limit = exp.Fetch(direction="FIRST", count=exp.maybe_copy(limit.expression))

         fetch = isinstance(limit, exp.Fetch)
@@ -1955,6 +1937,7 @@ class Generator:
         return f"PRIMARY KEY ({expressions}){options}"

     def if_sql(self, expression: exp.If) -> str:
+        expression = expression.copy()
         return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false")))

     def matchagainst_sql(self, expression: exp.MatchAgainst) -> str:
@@ -2261,7 +2244,7 @@ class Generator:
     def intdiv_sql(self, expression: exp.IntDiv) -> str:
         return self.sql(
             exp.Cast(
-                this=exp.Div(this=expression.this, expression=expression.expression),
+                this=exp.Div(this=expression.this.copy(), expression=expression.expression.copy()),
                 to=exp.DataType(this=exp.DataType.Type.INT),
             )
         )

diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py
index 9d4860e..54cf02b 100644
--- a/sqlglot/optimizer/normalize_identifiers.py
+++ b/sqlglot/optimizer/normalize_identifiers.py
@@ -41,5 +41,6 @@ def normalize_identifiers(expression, dialect=None):
     Returns:
         The transformed expression.
     """
-    expression = exp.maybe_parse(expression, dialect=dialect)
+    if isinstance(expression, str):
+        expression = exp.to_identifier(expression)
     return expression.transform(Dialect.get_or_raise(dialect).normalize_identifier, copy=False)
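Since normalize_identifiers no longer round-trips strings through maybe_parse, a bare string is treated as a single identifier rather than as SQL to parse. A quick sketch of the resulting behavior (Snowflake resolves unquoted identifiers as uppercase):

    from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

    # "Foo" is wrapped with exp.to_identifier, then case-normalized.
    ident = normalize_identifiers("Foo", dialect="snowflake")
    print(ident.sql(dialect="snowflake"))  # expected: FOO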
diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py
index c81fd00..b51601f 100644
--- a/sqlglot/optimizer/pushdown_projections.py
+++ b/sqlglot/optimizer/pushdown_projections.py
@@ -31,6 +31,7 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
     """
     # Map of Scope to all columns being selected by outer queries.
     schema = ensure_schema(schema)
+    source_column_alias_count = {}
     referenced_columns = defaultdict(set)

     # We build the scope tree (which is traversed in DFS postorder), then iterate
@@ -38,8 +39,9 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
     # columns for a particular scope are completely build by the time we get to it.
     for scope in reversed(traverse_scope(expression)):
         parent_selections = referenced_columns.get(scope, {SELECT_ALL})
+        alias_count = source_column_alias_count.get(scope, 0)

-        if scope.expression.args.get("distinct") or scope.parent and scope.parent.pivots:
+        if scope.expression.args.get("distinct") or (scope.parent and scope.parent.pivots):
            # We can't remove columns SELECT DISTINCT nor UNION DISTINCT. The same holds if
            # we select from a pivoted source in the parent scope.
             parent_selections = {SELECT_ALL}
@@ -59,7 +61,7 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
         if isinstance(scope.expression, exp.Select):
             if remove_unused_selections:
-                _remove_unused_selections(scope, parent_selections, schema)
+                _remove_unused_selections(scope, parent_selections, schema, alias_count)

             if scope.expression.is_star:
                 continue
@@ -72,15 +74,19 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True)
                 selects[table_name].add(col_name)

         # Push the selected columns down to the next scope
-        for name, (_, source) in scope.selected_sources.items():
+        for name, (node, source) in scope.selected_sources.items():
             if isinstance(source, Scope):
                 columns = selects.get(name) or set()
                 referenced_columns[source].update(columns)

+                column_aliases = node.alias_column_names
+                if column_aliases:
+                    source_column_alias_count[source] = len(column_aliases)
+
     return expression


-def _remove_unused_selections(scope, parent_selections, schema):
+def _remove_unused_selections(scope, parent_selections, schema, alias_count):
     order = scope.expression.args.get("order")

     if order:
@@ -93,11 +99,14 @@
     removed = False
     star = False

+    select_all = SELECT_ALL in parent_selections
+
     for selection in scope.expression.selects:
         name = selection.alias_or_name

-        if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs:
+        if select_all or name in parent_selections or name in order_refs or alias_count > 0:
             new_selections.append(selection)
+            alias_count -= 1
         else:
             if selection.is_star:
                 star = True

diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 9c34cef..952999d 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import Dialect, DialectType
 from sqlglot.errors import OptimizeError
 from sqlglot.helper import seq_get
 from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
+from sqlglot.optimizer.simplify import simplify_parens
 from sqlglot.schema import Schema, ensure_schema
@@ -58,6 +59,7 @@ def qualify_columns(
         if not isinstance(scope.expression, exp.UDTF):
             _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
             _qualify_outputs(scope)
+
         _expand_group_by(scope)
         _expand_order_by(scope, resolver)
@@ -85,7 +87,7 @@ def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) ->
     """
     Remove table column aliases.

-    (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
+    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
     """
     for derived_table in derived_tables:
         table_alias = derived_table.args.get("alias")
@@ -111,11 +113,11 @@ def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:

         columns = {}

-        for k in scope.selected_sources:
-            if k in ordered:
-                for column in resolver.get_source_columns(k):
-                    if column not in columns:
-                        columns[column] = k
+        for source_name in scope.selected_sources:
+            if source_name in ordered:
+                for column_name in resolver.get_source_columns(source_name):
+                    if column_name not in columns:
+                        columns[column_name] = source_name

         source_table = ordered[-1]
         ordered.append(join_table)
@@ -183,6 +185,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
         for column, *_ in walk_in_scope(node):
             if not isinstance(column, exp.Column):
                 continue
+
             table = resolver.get_table(column.name) if resolve_table and not column.table else None
             alias_expr, i = alias_to_expression.get(column.name, (None, 1))
             double_agg = (
@@ -198,7 +201,10 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
                 if literal_index:
                     column.replace(exp.Literal.number(i))
                 else:
-                    column.replace(alias_expr.copy())
+                    column = column.replace(exp.paren(alias_expr))
+                    simplified = simplify_parens(column)
+                    if simplified is not column:
+                        column.replace(simplified)

     for i, projection in enumerate(scope.expression.selects):
         replace_columns(projection)
@@ -213,7 +219,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
     scope.clear_cache()


-def _expand_group_by(scope: Scope):
+def _expand_group_by(scope: Scope) -> None:
     expression = scope.expression
     group = expression.args.get("group")
     if not group:
@@ -223,7 +229,7 @@ def _expand_group_by(scope: Scope):
     expression.set("group", group)


-def _expand_order_by(scope: Scope, resolver: Resolver):
+def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
     order = scope.expression.args.get("order")
     if not order:
         return
@@ -442,7 +448,7 @@ def _add_replace_columns(
     replace_columns[id(table)] = columns


-def _qualify_outputs(scope: Scope):
+def _qualify_outputs(scope: Scope) -> None:
     """Ensure all output columns are aliased"""
     new_selections = []
@@ -482,9 +488,9 @@ class Resolver:
     def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
         self.scope = scope
         self.schema = schema
-        self._source_columns = None
+        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
         self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
-        self._all_columns = None
+        self._all_columns: t.Optional[t.Set[str]] = None
         self._infer_schema = infer_schema

     def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
@@ -528,7 +534,7 @@ class Resolver:
         return exp.to_identifier(table_name)

     @property
-    def all_columns(self):
+    def all_columns(self) -> t.Set[str]:
         """All available columns of all sources in this scope"""
         if self._all_columns is None:
             self._all_columns = {
@@ -536,53 +542,67 @@ class Resolver:
             }
         return self._all_columns

-    def get_source_columns(self, name, only_visible=False):
-        """Resolve the source columns for a given source `name`"""
+    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
+        """Resolve the source columns for a given source `name`."""
         if name not in self.scope.sources:
             raise OptimizeError(f"Unknown table: {name}")

         source = self.scope.sources[name]

-        # If referencing a table, return the columns from the schema
         if isinstance(source, exp.Table):
-            return self.schema.column_names(source, only_visible)
+            columns = self.schema.column_names(source, only_visible)
+        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
+            columns = source.expression.alias_column_names
+        else:
+            columns = source.expression.named_selects

-        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
-            return source.expression.alias_column_names
+        node, _ = self.scope.selected_sources.get(name) or (None, None)
+        if isinstance(node, Scope):
+            column_aliases = node.expression.alias_column_names
+        elif isinstance(node, exp.Expression):
+            column_aliases = node.alias_column_names
+        else:
+            column_aliases = []

-        # Otherwise, if referencing another scope, return that scope's named selects
-        return source.expression.named_selects
+        # If the source's columns are aliased, their aliases shadow the corresponding column names
+        return [alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)]

-    def _get_all_source_columns(self):
+    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
         if self._source_columns is None:
             self._source_columns = {
-                k: self.get_source_columns(k)
-                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
+                source_name: self.get_source_columns(source_name)
+                for source_name, source in itertools.chain(
+                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
+                )
             }
         return self._source_columns

-    def _get_unambiguous_columns(self, source_columns):
+    def _get_unambiguous_columns(
+        self, source_columns: t.Dict[str, t.List[str]]
+    ) -> t.Dict[str, str]:
         """
         Find all the unambiguous columns in sources.

         Args:
-            source_columns (dict): Mapping of names to source columns
+            source_columns: Mapping of names to source columns.
+
         Returns:
-            dict: Mapping of column name to source name
+            Mapping of column name to source name.
         """
         if not source_columns:
             return {}

-        source_columns = list(source_columns.items())
+        source_columns_pairs = list(source_columns.items())

-        first_table, first_columns = source_columns[0]
+        first_table, first_columns = source_columns_pairs[0]
         unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
         all_columns = set(unambiguous_columns)

-        for table, columns in source_columns[1:]:
+        for table, columns in source_columns_pairs[1:]:
             unique = self._find_unique_columns(columns)
             ambiguous = set(all_columns).intersection(unique)
             all_columns.update(columns)
+
             for column in ambiguous:
                 unambiguous_columns.pop(column, None)
             for column in unique.difference(ambiguous):
@@ -591,7 +611,7 @@
         return unambiguous_columns

     @staticmethod
-    def _find_unique_columns(columns):
+    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
         """
         Find the unique columns in a list of columns.

@@ -601,7 +621,7 @@
         This is necessary because duplicate column names are ambiguous.
         """
-        counts = {}
+        counts: t.Dict[str, int] = {}
         for column in columns:
             counts[column] = counts.get(column, 0) + 1
         return {column for column, count in counts.items() if count == 1}
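The zip_longest in Resolver.get_source_columns pairs each source column with its positional alias, letting aliases shadow column names while unaliased trailing columns keep theirs. The rule in isolation:

    import itertools

    # For a source like (SELECT ...) AS t(x, y) exposing columns a, b, c:
    columns = ["a", "b", "c"]
    column_aliases = ["x", "y"]
    print([alias or name for name, alias in itertools.zip_longest(columns, column_aliases)])
    # ['x', 'y', 'c']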
""" - counts = {} + counts: t.Dict[str, int] = {} for column in columns: counts[column] = counts.get(column, 0) + 1 return {column for column, count in counts.items() if count == 1} diff --git a/sqlglot/parser.py b/sqlglot/parser.py index f714c8d..35a1744 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -248,7 +248,6 @@ class Parser(metaclass=_Parser): TokenType.FILTER, TokenType.FORMAT, TokenType.FULL, - TokenType.IF, TokenType.IS, TokenType.ISNULL, TokenType.INTERVAL, @@ -708,14 +707,10 @@ class Parser(metaclass=_Parser): SCHEMA_UNNAMED_CONSTRAINTS = {"CHECK", "FOREIGN KEY", "LIKE", "PRIMARY KEY", "UNIQUE"} NO_PAREN_FUNCTION_PARSERS = { - TokenType.ANY: lambda self: self.expression(exp.Any, this=self._parse_bitwise()), - TokenType.CASE: lambda self: self._parse_case(), - TokenType.IF: lambda self: self._parse_if(), - TokenType.NEXT_VALUE_FOR: lambda self: self.expression( - exp.NextValueFor, - this=self._parse_column(), - order=self._match(TokenType.OVER) and self._parse_wrapped(self._parse_order), - ), + "ANY": lambda self: self.expression(exp.Any, this=self._parse_bitwise()), + "CASE": lambda self: self._parse_case(), + "IF": lambda self: self._parse_if(), + "NEXT": lambda self: self._parse_next_value_for(), } FUNCTIONS_WITH_ALIASED_ARGS = {"STRUCT"} @@ -1162,7 +1157,7 @@ class Parser(metaclass=_Parser): def _parse_exists(self, not_: bool = False) -> t.Optional[bool]: return ( - self._match(TokenType.IF) + self._match_text_seq("IF") and (not not_ or self._match(TokenType.NOT)) and self._match(TokenType.EXISTS) ) @@ -1935,6 +1930,9 @@ class Parser(metaclass=_Parser): # https://prestodb.io/docs/current/sql/values.html return self.expression(exp.Tuple, expressions=[self._parse_conjunction()]) + def _parse_projections(self) -> t.List[t.Optional[exp.Expression]]: + return self._parse_expressions() + def _parse_select( self, nested: bool = False, table: bool = False, parse_subquery_alias: bool = True ) -> t.Optional[exp.Expression]: @@ -1974,14 +1972,14 @@ class Parser(metaclass=_Parser): self.raise_error("Cannot specify both ALL and DISTINCT after SELECT") limit = self._parse_limit(top=True) - expressions = self._parse_expressions() + projections = self._parse_projections() this = self.expression( exp.Select, kind=kind, hint=hint, distinct=distinct, - expressions=expressions, + expressions=projections, limit=limit, ) this.comments = comments @@ -3021,8 +3019,12 @@ class Parser(metaclass=_Parser): while True: if self._match_set(self.BITWISE): this = self.expression( - self.BITWISE[self._prev.token_type], this=this, expression=self._parse_term() + self.BITWISE[self._prev.token_type], + this=this, + expression=self._parse_term(), ) + elif self._match(TokenType.DQMARK): + this = self.expression(exp.Coalesce, this=this, expressions=self._parse_term()) elif self._match_pair(TokenType.LT, TokenType.LT): this = self.expression( exp.BitwiseLeftShift, this=this, expression=self._parse_term() @@ -3322,9 +3324,13 @@ class Parser(metaclass=_Parser): return None token_type = self._curr.token_type + this = self._curr.text + upper = this.upper() - if optional_parens and self._match_set(self.NO_PAREN_FUNCTION_PARSERS): - return self.NO_PAREN_FUNCTION_PARSERS[token_type](self) + parser = self.NO_PAREN_FUNCTION_PARSERS.get(upper) + if optional_parens and parser: + self._advance() + return parser(self) if not self._next or self._next.token_type != TokenType.L_PAREN: if optional_parens and token_type in self.NO_PAREN_FUNCTIONS: @@ -3336,12 +3342,9 @@ class Parser(metaclass=_Parser): if token_type not 
in self.FUNC_TOKENS: return None - this = self._curr.text - upper = this.upper() self._advance(2) parser = self.FUNCTION_PARSERS.get(upper) - if parser and not anonymous: this = parser(self) else: @@ -3368,7 +3371,7 @@ class Parser(metaclass=_Parser): else: this = self.expression(exp.Anonymous, this=this, expressions=args) - self._match(TokenType.R_PAREN, expression=this) + self._match_r_paren(this) return self._parse_window(this) def _parse_function_parameter(self) -> t.Optional[exp.Expression]: @@ -3703,7 +3706,11 @@ class Parser(metaclass=_Parser): self.expression(exp.Slice, expression=self._parse_conjunction()) ] else: - expressions = self._parse_csv(lambda: self._parse_slice(self._parse_conjunction())) + expressions = self._parse_csv( + lambda: self._parse_slice( + self._parse_alias(self._parse_conjunction(), explicit=True) + ) + ) # https://duckdb.org/docs/sql/data_types/struct.html#creating-structs if bracket_kind == TokenType.L_BRACE: @@ -3770,6 +3777,17 @@ class Parser(metaclass=_Parser): return self._parse_window(this) + def _parse_next_value_for(self) -> t.Optional[exp.Expression]: + if not self._match_text_seq("VALUE", "FOR"): + self._retreat(self._index - 1) + return None + + return self.expression( + exp.NextValueFor, + this=self._parse_column(), + order=self._match(TokenType.OVER) and self._parse_wrapped(self._parse_order), + ) + def _parse_extract(self) -> exp.Extract: this = self._parse_function() or self._parse_var() or self._parse_type() diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 729e47f..81bcc0b 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -21,6 +21,7 @@ class TokenType(AutoName): PLUS = auto() COLON = auto() DCOLON = auto() + DQMARK = auto() SEMICOLON = auto() STAR = auto() BACKSLASH = auto() @@ -215,7 +216,6 @@ class TokenType(AutoName): GROUPING_SETS = auto() HAVING = auto() HINT = auto() - IF = auto() IGNORE = auto() ILIKE = auto() ILIKE_ANY = auto() @@ -248,7 +248,6 @@ class TokenType(AutoName): MOD = auto() NATURAL = auto() NEXT = auto() - NEXT_VALUE_FOR = auto() NOTNULL = auto() NULL = auto() OFFSET = auto() @@ -504,6 +503,7 @@ class Tokenizer(metaclass=_Tokenizer): "#>>": TokenType.DHASH_ARROW, "<->": TokenType.LR_ARROW, "&&": TokenType.DAMP, + "??": TokenType.DQMARK, "ALL": TokenType.ALL, "ALWAYS": TokenType.ALWAYS, "AND": TokenType.AND, @@ -563,7 +563,6 @@ class Tokenizer(metaclass=_Tokenizer): "GROUP BY": TokenType.GROUP_BY, "GROUPING SETS": TokenType.GROUPING_SETS, "HAVING": TokenType.HAVING, - "IF": TokenType.IF, "ILIKE": TokenType.ILIKE, "IN": TokenType.IN, "INDEX": TokenType.INDEX, @@ -586,7 +585,6 @@ class Tokenizer(metaclass=_Tokenizer): "MERGE": TokenType.MERGE, "NATURAL": TokenType.NATURAL, "NEXT": TokenType.NEXT, - "NEXT VALUE FOR": TokenType.NEXT_VALUE_FOR, "NOT": TokenType.NOT, "NOTNULL": TokenType.NOTNULL, "NULL": TokenType.NULL, |
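Two parser-facing effects of the token changes above: "??" now tokenizes as DQMARK and parses into COALESCE, and IF / NEXT VALUE FOR are matched as soft keywords rather than dedicated token types. A hedged example against the default dialect (outputs are what the new code paths imply, not taken from this commit):

    import sqlglot

    # "a ?? b" becomes an exp.Coalesce node at the bitwise precedence level.
    print(sqlglot.parse_one("SELECT a ?? b FROM t").sql())
    # expected: SELECT COALESCE(a, b) FROM t

    # NEXT VALUE FOR is now recognized by _parse_next_value_for via the
    # "NEXT" entry in NO_PAREN_FUNCTION_PARSERS.
    print(sqlglot.parse_one("SELECT NEXT VALUE FOR db.seq").sql())
    # expected: SELECT NEXT VALUE FOR db.seq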