diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-15 05:02:18 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-15 05:02:18 +0000 |
commit | 41f1f5740d2140bfd3b2a282ca1087a4b576679a (patch) | |
tree | 0b1eb5ba5c759d08b05d56e50675784b6170f955 /sqlglot | |
parent | Releasing debian version 23.7.0-1. (diff) | |
download | sqlglot-41f1f5740d2140bfd3b2a282ca1087a4b576679a.tar.xz sqlglot-41f1f5740d2140bfd3b2a282ca1087a4b576679a.zip |
Merging upstream version 23.10.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
26 files changed, 548 insertions, 297 deletions
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index b4dd2c6..81b7d61 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -536,7 +536,7 @@ def year(col: ColumnOrName) -> Column: def quarter(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "QUARTER") + return Column.invoke_expression_over_column(col, expression.Quarter) def month(col: ColumnOrName) -> Column: diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 2167ba2..a7b4895 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -15,7 +15,7 @@ from sqlglot.dialects.dialect import ( build_formatted_time, filter_array_using_unnest, if_sql, - inline_array_sql, + inline_array_unless_query, max_or_greatest, min_or_least, no_ilike_sql, @@ -80,29 +80,6 @@ def _create_sql(self: BigQuery.Generator, expression: exp.Create) -> str: return self.create_sql(expression) -def _unqualify_unnest(expression: exp.Expression) -> exp.Expression: - """Remove references to unnest table aliases since bigquery doesn't allow them. - - These are added by the optimizer's qualify_column step. - """ - from sqlglot.optimizer.scope import find_all_in_scope - - if isinstance(expression, exp.Select): - unnest_aliases = { - unnest.alias - for unnest in find_all_in_scope(expression, exp.Unnest) - if isinstance(unnest.parent, (exp.From, exp.Join)) - } - if unnest_aliases: - for column in expression.find_all(exp.Column): - if column.table in unnest_aliases: - column.set("table", None) - elif column.db in unnest_aliases: - column.set("db", None) - - return expression - - # https://issuetracker.google.com/issues/162294746 # workaround for bigquery bug when grouping by an expression and then ordering # WITH x AS (SELECT 1 y) @@ -197,8 +174,8 @@ def _ts_or_ds_add_sql(self: BigQuery.Generator, expression: exp.TsOrDsAdd) -> st def _ts_or_ds_diff_sql(self: BigQuery.Generator, expression: exp.TsOrDsDiff) -> str: - expression.this.replace(exp.cast(expression.this, "TIMESTAMP", copy=True)) - expression.expression.replace(exp.cast(expression.expression, "TIMESTAMP", copy=True)) + expression.this.replace(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP)) + expression.expression.replace(exp.cast(expression.expression, exp.DataType.Type.TIMESTAMP)) unit = unit_to_var(expression) return self.func("DATE_DIFF", expression.this, expression.expression, unit) @@ -214,7 +191,9 @@ def _unix_to_time_sql(self: BigQuery.Generator, expression: exp.UnixToTime) -> s if scale == exp.UnixToTime.MICROS: return self.func("TIMESTAMP_MICROS", timestamp) - unix_seconds = exp.cast(exp.Div(this=timestamp, expression=exp.func("POW", 10, scale)), "int64") + unix_seconds = exp.cast( + exp.Div(this=timestamp, expression=exp.func("POW", 10, scale)), exp.DataType.Type.BIGINT + ) return self.func("TIMESTAMP_SECONDS", unix_seconds) @@ -576,6 +555,7 @@ class BigQuery(Dialect): exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), exp.ArgMax: arg_max_or_min_no_count("MAX_BY"), exp.ArgMin: arg_max_or_min_no_count("MIN_BY"), + exp.Array: inline_array_unless_query, exp.ArrayContains: _array_contains_sql, exp.ArrayFilter: filter_array_using_unnest, exp.ArraySize: rename_func("ARRAY_LENGTH"), @@ -629,7 +609,7 @@ class BigQuery(Dialect): exp.Select: transforms.preprocess( [ transforms.explode_to_unnest(), - _unqualify_unnest, + transforms.unqualify_unnest, transforms.eliminate_distinct_on, _alias_ordered_group, transforms.eliminate_semi_and_anti_joins, @@ -843,13 +823,6 @@ class BigQuery(Dialect): def trycast_sql(self, expression: exp.TryCast) -> str: return self.cast_sql(expression, safe_prefix="SAFE_") - def array_sql(self, expression: exp.Array) -> str: - first_arg = seq_get(expression.expressions, 0) - if isinstance(first_arg, exp.Query): - return f"ARRAY{self.wrap(self.sql(first_arg))}" - - return inline_array_sql(self, expression) - def bracket_sql(self, expression: exp.Bracket) -> str: this = expression.this expressions = expression.expressions diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 631dc30..34ee529 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -629,7 +629,8 @@ class ClickHouse(Dialect): exp.CountIf: rename_func("countIf"), exp.CompressColumnConstraint: lambda self, e: f"CODEC({self.expressions(e, key='this', flat=True)})", - exp.ComputedColumnConstraint: lambda self, e: f"ALIAS {self.sql(e, 'this')}", + exp.ComputedColumnConstraint: lambda self, + e: f"{'MATERIALIZED' if e.args.get('persisted') else 'ALIAS'} {self.sql(e, 'this')}", exp.CurrentDate: lambda self, e: self.func("CURRENT_DATE"), exp.DateAdd: date_delta_sql("DATE_ADD"), exp.DateDiff: date_delta_sql("DATE_DIFF"), @@ -667,6 +668,7 @@ class ClickHouse(Dialect): TABLE_HINTS = False EXPLICIT_UNION = True GROUPINGS_SEP = "" + OUTER_UNION_MODIFIERS = False # there's no list in docs, but it can be found in Clickhouse code # see `ClickHouse/src/Parsers/ParserCreate*.cpp` diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 81057c2..5a47438 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -562,7 +562,7 @@ def if_sql( def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: this = expression.this if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: - this.replace(exp.cast(this, "json")) + this.replace(exp.cast(this, exp.DataType.Type.JSON)) return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>") @@ -571,6 +571,13 @@ def inline_array_sql(self: Generator, expression: exp.Array) -> str: return f"[{self.expressions(expression, flat=True)}]" +def inline_array_unless_query(self: Generator, expression: exp.Array) -> str: + elem = seq_get(expression.expressions, 0) + if isinstance(elem, exp.Expression) and elem.find(exp.Query): + return self.func("ARRAY", elem) + return inline_array_sql(self, expression) + + def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: return self.like_sql( exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression) @@ -765,11 +772,11 @@ def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: from sqlglot.optimizer.annotate_types import annotate_types target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP - return self.sql(exp.cast(expression.this, to=target_type)) + return self.sql(exp.cast(expression.this, target_type)) if expression.text("expression").lower() in TIMEZONES: return self.sql( exp.AtTimeZone( - this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP), + this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), zone=expression.expression, ) ) @@ -806,11 +813,11 @@ def right_to_substring_sql(self: Generator, expression: exp.Left) -> str: def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str: - return self.sql(exp.cast(expression.this, "timestamp")) + return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP)) def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: - return self.sql(exp.cast(expression.this, "date")) + return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE)) # Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8 @@ -1023,7 +1030,7 @@ def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") - return self.sql(exp.cast(minus_one_day, "date")) + return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE)) def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index 0a00d92..06f49d5 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -19,7 +19,7 @@ def _str_to_date(self: Drill.Generator, expression: exp.StrToDate) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format == Drill.DATE_FORMAT: - return self.sql(exp.cast(this, "date")) + return self.sql(exp.cast(this, exp.DataType.Type.DATE)) return self.func("TO_DATE", this, time_format) @@ -134,7 +134,7 @@ class Drill(Dialect): [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins] ), exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)), - exp.TimeStrToDate: lambda self, e: self.sql(exp.cast(e.this, "date")), + exp.TimeStrToDate: lambda self, e: self.sql(exp.cast(e.this, exp.DataType.Type.DATE)), exp.TimeStrToTime: timestrtotime_sql, exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)), diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 6a1d07a..6486dda 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -15,7 +15,7 @@ from sqlglot.dialects.dialect import ( datestrtodate_sql, encode_decode_sql, build_formatted_time, - inline_array_sql, + inline_array_unless_query, no_comment_column_constraint_sql, no_safe_divide_sql, no_timestamp_sql, @@ -312,6 +312,15 @@ class DuckDB(Dialect): ), } + def _parse_bracket( + self, this: t.Optional[exp.Expression] = None + ) -> t.Optional[exp.Expression]: + bracket = super()._parse_bracket(this) + if isinstance(bracket, exp.Bracket): + bracket.set("returns_list_for_maps", True) + + return bracket + def _parse_map(self) -> exp.ToMap | exp.Map: if self._match(TokenType.L_BRACE, advance=False): return self.expression(exp.ToMap, this=self._parse_bracket()) @@ -370,11 +379,7 @@ class DuckDB(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, exp.ApproxDistinct: approx_count_distinct_sql, - exp.Array: lambda self, e: ( - self.func("ARRAY", e.expressions[0]) - if e.expressions and e.expressions[0].find(exp.Select) - else inline_array_sql(self, e) - ), + exp.Array: inline_array_unless_query, exp.ArrayFilter: rename_func("LIST_FILTER"), exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.ArgMax: arg_max_or_min_no_count("ARG_MAX"), @@ -416,8 +421,8 @@ class DuckDB(Dialect): exp.MonthsBetween: lambda self, e: self.func( "DATEDIFF", "'month'", - exp.cast(e.expression, "timestamp", copy=True), - exp.cast(e.this, "timestamp", copy=True), + exp.cast(e.expression, exp.DataType.Type.TIMESTAMP, copy=True), + exp.cast(e.this, exp.DataType.Type.TIMESTAMP, copy=True), ), exp.ParseJSON: rename_func("JSON"), exp.PercentileCont: _rename_unless_within_group("PERCENTILE_CONT", "QUANTILE_CONT"), @@ -452,9 +457,11 @@ class DuckDB(Dialect): "DATE_DIFF", exp.Literal.string(e.unit), e.expression, e.this ), exp.TimestampTrunc: timestamptrunc_sql, - exp.TimeStrToDate: lambda self, e: self.sql(exp.cast(e.this, "date")), + exp.TimeStrToDate: lambda self, e: self.sql(exp.cast(e.this, exp.DataType.Type.DATE)), exp.TimeStrToTime: timestrtotime_sql, - exp.TimeStrToUnix: lambda self, e: self.func("EPOCH", exp.cast(e.this, "timestamp")), + exp.TimeStrToUnix: lambda self, e: self.func( + "EPOCH", exp.cast(e.this, exp.DataType.Type.TIMESTAMP) + ), exp.TimeToStr: lambda self, e: self.func("STRFTIME", e.this, self.format_time(e)), exp.TimeToUnix: rename_func("EPOCH"), exp.TsOrDiToDi: lambda self, @@ -463,8 +470,8 @@ class DuckDB(Dialect): exp.TsOrDsDiff: lambda self, e: self.func( "DATE_DIFF", f"'{e.args.get('unit') or 'DAY'}'", - exp.cast(e.expression, "TIMESTAMP"), - exp.cast(e.this, "TIMESTAMP"), + exp.cast(e.expression, exp.DataType.Type.TIMESTAMP), + exp.cast(e.this, exp.DataType.Type.TIMESTAMP), ), exp.UnixToStr: lambda self, e: self.func( "STRFTIME", self.func("TO_TIMESTAMP", e.this), self.format_time(e) @@ -593,7 +600,19 @@ class DuckDB(Dialect): return super().generateseries_sql(expression) def bracket_sql(self, expression: exp.Bracket) -> str: - if isinstance(expression.this, exp.Array): - expression.this.replace(exp.paren(expression.this)) + this = expression.this + if isinstance(this, exp.Array): + this.replace(exp.paren(this)) + + bracket = super().bracket_sql(expression) + + if not expression.args.get("returns_list_for_maps"): + if not this.type: + from sqlglot.optimizer.annotate_types import annotate_types + + this = annotate_types(this) + + if this.is_type(exp.DataType.Type.MAP): + bracket = f"({bracket})[1]" - return super().bracket_sql(expression) + return bracket diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 1d53346..611a155 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -710,7 +710,9 @@ class MySQL(Dialect): ), exp.TimestampSub: date_add_interval_sql("DATE", "SUB"), exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), - exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)), + exp.TimeStrToTime: lambda self, e: self.sql( + exp.cast(e.this, exp.DataType.Type.DATETIME, copy=True) + ), exp.TimeToStr: _remove_ts_or_ds_to_date( lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)) ), diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 11398ed..7cbcc23 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -510,6 +510,9 @@ class Postgres(Dialect): exp.TsOrDsAdd: _date_add_sql("+"), exp.TsOrDsDiff: _date_diff_sql, exp.UnixToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this), + exp.TimeToUnix: lambda self, e: self.func( + "DATE_PART", exp.Literal.string("epoch"), e.this + ), exp.VariancePop: rename_func("VAR_POP"), exp.Variance: rename_func("VAR_SAMP"), exp.Xor: bool_xor_sql, diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 25bba96..6c23bdf 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -90,8 +90,10 @@ def _str_to_time_sql( def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate) -> str: time_format = self.format_time(expression) if time_format and time_format not in (Presto.TIME_FORMAT, Presto.DATE_FORMAT): - return self.sql(exp.cast(_str_to_time_sql(self, expression), "DATE")) - return self.sql(exp.cast(exp.cast(expression.this, "TIMESTAMP"), "DATE")) + return self.sql(exp.cast(_str_to_time_sql(self, expression), exp.DataType.Type.DATE)) + return self.sql( + exp.cast(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), exp.DataType.Type.DATE) + ) def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str: @@ -101,8 +103,8 @@ def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str: def _ts_or_ds_diff_sql(self: Presto.Generator, expression: exp.TsOrDsDiff) -> str: - this = exp.cast(expression.this, "TIMESTAMP") - expr = exp.cast(expression.expression, "TIMESTAMP") + this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMP) + expr = exp.cast(expression.expression, exp.DataType.Type.TIMESTAMP) unit = unit_to_str(expression) return self.func("DATE_DIFF", unit, expr, this) @@ -222,6 +224,8 @@ class Presto(Dialect): "IPPREFIX": TokenType.IPPREFIX, } + KEYWORDS.pop("QUALIFY") + class Parser(parser.Parser): VALUES_FOLLOWED_BY_PAREN = False @@ -445,7 +449,7 @@ class Presto(Dialect): # timezone involved, we wrap it in a `TRY` call and use `PARSE_DATETIME` as a fallback, # which seems to be using the same time mapping as Hive, as per: # https://joda-time.sourceforge.net/apidocs/org/joda/time/format/DateTimeFormat.html - value_as_text = exp.cast(expression.this, "text") + value_as_text = exp.cast(expression.this, exp.DataType.Type.TEXT) parse_without_tz = self.func("DATE_PARSE", value_as_text, self.format_time(expression)) parse_with_tz = self.func( "PARSE_DATETIME", diff --git a/sqlglot/dialects/prql.py b/sqlglot/dialects/prql.py index 3005753..3ee91a8 100644 --- a/sqlglot/dialects/prql.py +++ b/sqlglot/dialects/prql.py @@ -7,7 +7,13 @@ from sqlglot.dialects.dialect import Dialect from sqlglot.tokens import TokenType +def _select_all(table: exp.Expression) -> t.Optional[exp.Select]: + return exp.select("*").from_(table, copy=False) if table else None + + class PRQL(Dialect): + DPIPE_IS_STRING_CONCAT = False + class Tokenizer(tokens.Tokenizer): IDENTIFIERS = ["`"] QUOTES = ["'", '"'] @@ -26,10 +32,27 @@ class PRQL(Dialect): } class Parser(parser.Parser): + CONJUNCTION = { + **parser.Parser.CONJUNCTION, + TokenType.DAMP: exp.And, + TokenType.DPIPE: exp.Or, + } + TRANSFORM_PARSERS = { "DERIVE": lambda self, query: self._parse_selection(query), "SELECT": lambda self, query: self._parse_selection(query, append=False), "TAKE": lambda self, query: self._parse_take(query), + "FILTER": lambda self, query: query.where(self._parse_conjunction()), + "APPEND": lambda self, query: query.union( + _select_all(self._parse_table()), distinct=False, copy=False + ), + "REMOVE": lambda self, query: query.except_( + _select_all(self._parse_table()), distinct=False, copy=False + ), + "INTERSECT": lambda self, query: query.intersect( + _select_all(self._parse_table()), distinct=False, copy=False + ), + "SORT": lambda self, query: self._parse_order_by(query), } def _parse_statement(self) -> t.Optional[exp.Expression]: @@ -81,6 +104,24 @@ class PRQL(Dialect): num = self._parse_number() # TODO: TAKE for ranges a..b return query.limit(num) if num else None + def _parse_ordered( + self, parse_method: t.Optional[t.Callable] = None + ) -> t.Optional[exp.Ordered]: + asc = self._match(TokenType.PLUS) + desc = self._match(TokenType.DASH) or (asc and False) + term = term = super()._parse_ordered(parse_method=parse_method) + if term and desc: + term.set("desc", True) + term.set("nulls_first", False) + return term + + def _parse_order_by(self, query: exp.Select) -> t.Optional[exp.Query]: + l_brace = self._match(TokenType.L_BRACE) + expressions = self._parse_csv(self._parse_ordered) + if l_brace and not self._match(TokenType.R_BRACE): + self.raise_error("Expecting }") + return query.order_by(self.expression(exp.Order, expressions=expressions), copy=False) + def _parse_expression(self) -> t.Optional[exp.Expression]: if self._next and self._next.token_type == TokenType.ALIAS: alias = self._parse_id_var(True) diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 1f0c411..7a86c61 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -167,7 +167,11 @@ class Redshift(Postgres): exp.GroupConcat: rename_func("LISTAGG"), exp.ParseJSON: rename_func("JSON_PARSE"), exp.Select: transforms.preprocess( - [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins] + [ + transforms.eliminate_distinct_on, + transforms.eliminate_semi_and_anti_joins, + transforms.unqualify_unnest, + ] ), exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", @@ -203,7 +207,7 @@ class Redshift(Postgres): return "" arg = self.sql(seq_get(args, 0)) - alias = self.expressions(expression.args.get("alias"), key="columns") + alias = self.expressions(expression.args.get("alias"), key="columns", flat=True) return f"{arg} AS {alias}" if alias else arg def with_properties(self, properties: exp.Properties) -> str: diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 73a9166..41d5b65 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -818,7 +818,7 @@ class Snowflake(Dialect): exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToTime: timestrtotime_sql, exp.TimeToStr: lambda self, e: self.func( - "TO_CHAR", exp.cast(e.this, "timestamp"), self.format_time(e) + "TO_CHAR", exp.cast(e.this, exp.DataType.Type.TIMESTAMP), self.format_time(e) ), exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", exp.ToArray: rename_func("TO_ARRAY"), diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 88b5ddc..9bb9a5c 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -6,7 +6,7 @@ from sqlglot import exp from sqlglot.dialects.dialect import rename_func, unit_to_var from sqlglot.dialects.hive import _build_with_ignore_nulls from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider -from sqlglot.helper import seq_get +from sqlglot.helper import ensure_list, seq_get from sqlglot.transforms import ( ctas_with_tmp_tables_to_create_tmp_view, remove_unique_constraints, @@ -63,6 +63,9 @@ class Spark(Spark2): **Spark2.Parser.FUNCTIONS, "ANY_VALUE": _build_with_ignore_nulls(exp.AnyValue), "DATEDIFF": _build_datediff, + "TRY_ELEMENT_AT": lambda args: exp.Bracket( + this=seq_get(args, 0), expressions=ensure_list(seq_get(args, 1)), safe=True + ), } def _parse_generated_as_identity( @@ -112,6 +115,13 @@ class Spark(Spark2): TRANSFORMS.pop(exp.DateDiff) TRANSFORMS.pop(exp.Group) + def bracket_sql(self, expression: exp.Bracket) -> str: + if expression.args.get("safe"): + key = seq_get(self.bracket_offset_expressions(expression), 0) + return self.func("TRY_ELEMENT_AT", expression.this, key) + + return super().bracket_sql(expression) + def computedcolumnconstraint_sql(self, expression: exp.ComputedColumnConstraint) -> str: return f"GENERATED ALWAYS AS ({self.sql(expression, 'this')})" diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index 069916f..5264f39 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -48,7 +48,7 @@ def _unix_to_time_sql(self: Spark2.Generator, expression: exp.UnixToTime) -> str timestamp = expression.this if scale is None: - return self.sql(exp.cast(exp.func("from_unixtime", timestamp), "timestamp")) + return self.sql(exp.cast(exp.func("from_unixtime", timestamp), exp.DataType.Type.TIMESTAMP)) if scale == exp.UnixToTime.SECONDS: return self.func("TIMESTAMP_SECONDS", timestamp) if scale == exp.UnixToTime.MILLIS: @@ -129,11 +129,7 @@ class Spark2(Hive): "DOUBLE": _build_as_cast("double"), "FLOAT": _build_as_cast("float"), "FROM_UTC_TIMESTAMP": lambda args: exp.AtTimeZone( - this=exp.cast_unless( - seq_get(args, 0) or exp.Var(this=""), - exp.DataType.build("timestamp"), - exp.DataType.build("timestamp"), - ), + this=exp.cast(seq_get(args, 0) or exp.Var(this=""), exp.DataType.Type.TIMESTAMP), zone=seq_get(args, 1), ), "INT": _build_as_cast("int"), @@ -150,11 +146,7 @@ class Spark2(Hive): ), "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list, "TO_UTC_TIMESTAMP": lambda args: exp.FromTimeZone( - this=exp.cast_unless( - seq_get(args, 0) or exp.Var(this=""), - exp.DataType.build("timestamp"), - exp.DataType.build("timestamp"), - ), + this=exp.cast(seq_get(args, 0) or exp.Var(this=""), exp.DataType.Type.TIMESTAMP), zone=seq_get(args, 1), ), "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)), diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index a65e10e..feb2097 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -13,6 +13,29 @@ from sqlglot.dialects.dialect import ( from sqlglot.tokens import TokenType +def _date_add_sql( + kind: t.Literal["+", "-"], +) -> t.Callable[[Teradata.Generator, exp.DateAdd | exp.DateSub], str]: + def func(self: Teradata.Generator, expression: exp.DateAdd | exp.DateSub) -> str: + this = self.sql(expression, "this") + unit = expression.args.get("unit") + value = self._simplify_unless_literal(expression.expression) + + if not isinstance(value, exp.Literal): + self.unsupported("Cannot add non literal") + + if value.is_negative: + kind_to_op = {"+": "-", "-": "+"} + value = exp.Literal.string(value.name[1:]) + else: + kind_to_op = {"+": "+", "-": "-"} + value.set("is_string", True) + + return f"{this} {kind_to_op[kind]} {self.sql(exp.Interval(this=value, unit=unit))}" + + return func + + class Teradata(Dialect): SUPPORTS_SEMI_ANTI_JOIN = False TYPED_DIVISION = True @@ -189,6 +212,7 @@ class Teradata(Dialect): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, exp.DataType.Type.GEOMETRY: "ST_GEOMETRY", + exp.DataType.Type.DOUBLE: "DOUBLE PRECISION", } PROPERTIES_LOCATION = { @@ -214,6 +238,10 @@ class Teradata(Dialect): exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.ToNumber: to_number_with_nls_param, exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}", + exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", + exp.DateAdd: _date_add_sql("+"), + exp.DateSub: _date_add_sql("-"), + exp.Quarter: lambda self, e: self.sql(exp.Extract(this="QUARTER", expression=e.this)), } def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: @@ -276,3 +304,25 @@ class Teradata(Dialect): return f"{this_name}{this_properties}{self.sep()}{this_schema}" return super().createable_sql(expression, locations) + + def extract_sql(self, expression: exp.Extract) -> str: + this = self.sql(expression, "this") + if this.upper() != "QUARTER": + return super().extract_sql(expression) + + to_char = exp.func("to_char", expression.expression, exp.Literal.string("Q")) + return self.sql(exp.cast(to_char, exp.DataType.Type.INT)) + + def interval_sql(self, expression: exp.Interval) -> str: + multiplier = 0 + unit = expression.text("unit") + + if unit.startswith("WEEK"): + multiplier = 7 + elif unit.startswith("QUARTER"): + multiplier = 90 + + if multiplier: + return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('DAY')))})" + + return super().interval_sql(expression) diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 8e06be6..6eed46d 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -109,7 +109,7 @@ def _build_formatted_time( assert len(args) == 2 return exp_class( - this=exp.cast(args[1], "datetime"), + this=exp.cast(args[1], exp.DataType.Type.DATETIME), format=exp.Literal.string( format_time( args[0].name.lower(), @@ -726,6 +726,7 @@ class TSQL(Dialect): SUPPORTS_SELECT_INTO = True JSON_PATH_BRACKETED_KEY_SUPPORTED = False SUPPORTS_TO_NUMBER = False + OUTER_UNION_MODIFIERS = False EXPRESSIONS_WITHOUT_NESTED_CTES = { exp.Delete, @@ -882,13 +883,6 @@ class TSQL(Dialect): return rename_func("DATETIMEFROMPARTS")(self, expression) - def set_operations(self, expression: exp.Union) -> str: - limit = expression.args.get("limit") - if limit: - return self.sql(expression.limit(limit.pop(), copy=False)) - - return super().set_operations(expression) - def setitem_sql(self, expression: exp.SetItem) -> str: this = expression.this if isinstance(this, exp.EQ) and not isinstance(this.left, exp.Parameter): diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index e79c04b..5adbb1e 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -58,6 +58,7 @@ class _Expression(type): SQLGLOT_META = "sqlglot.meta" TABLE_PARTS = ("this", "db", "catalog") +COLUMN_PARTS = ("this", "table", "db", "catalog") class Expression(metaclass=_Expression): @@ -176,6 +177,15 @@ class Expression(metaclass=_Expression): return isinstance(self, Literal) and not self.args["is_string"] @property + def is_negative(self) -> bool: + """ + Checks whether an expression is negative. + + Handles both exp.Neg and Literal numbers with "-" which come from optimizer.simplify. + """ + return isinstance(self, Neg) or (self.is_number and self.this.startswith("-")) + + @property def is_int(self) -> bool: """ Checks whether a Literal expression is an integer. @@ -845,10 +855,14 @@ class Expression(metaclass=_Expression): copy: bool = True, **opts, ) -> In: + subquery = maybe_parse(query, copy=copy, **opts) if query else None + if subquery and not isinstance(subquery, Subquery): + subquery = subquery.subquery(copy=False) + return In( this=maybe_copy(self, copy), expressions=[convert(e, copy=copy) for e in expressions], - query=maybe_parse(query, copy=copy, **opts) if query else None, + query=subquery, unnest=( Unnest( expressions=[ @@ -1018,14 +1032,14 @@ class Query(Expression): return Subquery(this=instance, alias=alias) def limit( - self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts - ) -> Select: + self: Q, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts + ) -> Q: """ Adds a LIMIT clause to this query. Example: >>> select("1").union(select("1")).limit(1).sql() - 'SELECT * FROM (SELECT 1 UNION SELECT 1) AS _l_0 LIMIT 1' + 'SELECT 1 UNION SELECT 1 LIMIT 1' Args: expression: the SQL code string to parse. @@ -1039,10 +1053,90 @@ class Query(Expression): Returns: A limited Select expression. """ - return ( - select("*") - .from_(self.subquery(alias="_l_0", copy=copy)) - .limit(expression, dialect=dialect, copy=False, **opts) + return _apply_builder( + expression=expression, + instance=self, + arg="limit", + into=Limit, + prefix="LIMIT", + dialect=dialect, + copy=copy, + into_arg="expression", + **opts, + ) + + def offset( + self: Q, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts + ) -> Q: + """ + Set the OFFSET expression. + + Example: + >>> Select().from_("tbl").select("x").offset(10).sql() + 'SELECT x FROM tbl OFFSET 10' + + Args: + expression: the SQL code string to parse. + This can also be an integer. + If a `Offset` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in a `Offset`. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + The modified Select expression. + """ + return _apply_builder( + expression=expression, + instance=self, + arg="offset", + into=Offset, + prefix="OFFSET", + dialect=dialect, + copy=copy, + into_arg="expression", + **opts, + ) + + def order_by( + self: Q, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Q: + """ + Set the ORDER BY expression. + + Example: + >>> Select().from_("tbl").select("x").order_by("x DESC").sql() + 'SELECT x FROM tbl ORDER BY x DESC' + + Args: + *expressions: the SQL code strings to parse. + If a `Group` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in a `Order`. + append: if `True`, add to any existing expressions. + Otherwise, this flattens all the `Order` expression into a single expression. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + The modified Select expression. + """ + return _apply_child_list_builder( + *expressions, + instance=self, + arg="order", + append=append, + copy=copy, + prefix="ORDER BY", + into=Order, + dialect=dialect, + **opts, ) @property @@ -1536,7 +1630,13 @@ class SwapTable(Expression): class Comment(Expression): - arg_types = {"this": True, "kind": True, "expression": True, "exists": False} + arg_types = { + "this": True, + "kind": True, + "expression": True, + "exists": False, + "materialized": False, + } class Comprehension(Expression): @@ -1642,6 +1742,10 @@ class ExcludeColumnConstraint(ColumnConstraintKind): pass +class EphemeralColumnConstraint(ColumnConstraintKind): + arg_types = {"this": False} + + class WithOperator(Expression): arg_types = {"this": True, "op": True} @@ -2221,6 +2325,13 @@ class Lateral(UDTF): } +class MatchRecognizeMeasure(Expression): + arg_types = { + "this": True, + "window_frame": False, + } + + class MatchRecognize(Expression): arg_types = { "partition_by": False, @@ -3051,46 +3162,6 @@ class Select(Query): **opts, ) - def order_by( - self, - *expressions: t.Optional[ExpOrStr], - append: bool = True, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Select: - """ - Set the ORDER BY expression. - - Example: - >>> Select().from_("tbl").select("x").order_by("x DESC").sql() - 'SELECT x FROM tbl ORDER BY x DESC' - - Args: - *expressions: the SQL code strings to parse. - If a `Group` instance is passed, this is used as-is. - If another `Expression` instance is passed, it will be wrapped in a `Order`. - append: if `True`, add to any existing expressions. - Otherwise, this flattens all the `Order` expression into a single expression. - dialect: the dialect used to parse the input expression. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - The modified Select expression. - """ - return _apply_child_list_builder( - *expressions, - instance=self, - arg="order", - append=append, - copy=copy, - prefix="ORDER BY", - into=Order, - dialect=dialect, - **opts, - ) - def sort_by( self, *expressions: t.Optional[ExpOrStr], @@ -3171,55 +3242,6 @@ class Select(Query): **opts, ) - def limit( - self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts - ) -> Select: - return _apply_builder( - expression=expression, - instance=self, - arg="limit", - into=Limit, - prefix="LIMIT", - dialect=dialect, - copy=copy, - into_arg="expression", - **opts, - ) - - def offset( - self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts - ) -> Select: - """ - Set the OFFSET expression. - - Example: - >>> Select().from_("tbl").select("x").offset(10).sql() - 'SELECT x FROM tbl OFFSET 10' - - Args: - expression: the SQL code string to parse. - This can also be an integer. - If a `Offset` instance is passed, this is used as-is. - If another `Expression` instance is passed, it will be wrapped in a `Offset`. - dialect: the dialect used to parse the input expression. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - The modified Select expression. - """ - return _apply_builder( - expression=expression, - instance=self, - arg="offset", - into=Offset, - prefix="OFFSET", - dialect=dialect, - copy=copy, - into_arg="expression", - **opts, - ) - def select( self, *expressions: t.Optional[ExpOrStr], @@ -4214,7 +4236,7 @@ class Dot(Binary): parts.reverse() - for arg in ("this", "table", "db", "catalog"): + for arg in COLUMN_PARTS: part = this.args.get(arg) if isinstance(part, Expression): @@ -4395,7 +4417,13 @@ class Between(Predicate): class Bracket(Condition): # https://cloud.google.com/bigquery/docs/reference/standard-sql/operators#array_subscript_operator - arg_types = {"this": True, "expressions": True, "offset": False, "safe": False} + arg_types = { + "this": True, + "expressions": True, + "offset": False, + "safe": False, + "returns_list_for_maps": False, + } @property def output_name(self) -> str: @@ -5458,6 +5486,10 @@ class ApproxQuantile(Quantile): arg_types = {"this": True, "quantile": True, "accuracy": False, "weight": False} +class Quarter(Func): + pass + + class Rand(Func): _sql_names = ["RAND", "RANDOM"] arg_types = {"this": False} @@ -6620,17 +6652,9 @@ def to_interval(interval: str | Literal) -> Interval: ) -@t.overload -def to_table(sql_path: str | Table, **kwargs) -> Table: ... - - -@t.overload -def to_table(sql_path: None, **kwargs) -> None: ... - - def to_table( - sql_path: t.Optional[str | Table], dialect: DialectType = None, copy: bool = True, **kwargs -) -> t.Optional[Table]: + sql_path: str | Table, dialect: DialectType = None, copy: bool = True, **kwargs +) -> Table: """ Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional. If a table is passed in then that table is returned. @@ -6644,35 +6668,54 @@ def to_table( Returns: A table expression. """ - if sql_path is None or isinstance(sql_path, Table): + if isinstance(sql_path, Table): return maybe_copy(sql_path, copy=copy) - if not isinstance(sql_path, str): - raise ValueError(f"Invalid type provided for a table: {type(sql_path)}") table = maybe_parse(sql_path, into=Table, dialect=dialect) - if table: - for k, v in kwargs.items(): - table.set(k, v) + + for k, v in kwargs.items(): + table.set(k, v) return table -def to_column(sql_path: str | Column, **kwargs) -> Column: +def to_column( + sql_path: str | Column, + quoted: t.Optional[bool] = None, + dialect: DialectType = None, + copy: bool = True, + **kwargs, +) -> Column: """ - Create a column from a `[table].[column]` sql path. Schema is optional. - + Create a column from a `[table].[column]` sql path. Table is optional. If a column is passed in then that column is returned. Args: - sql_path: `[table].[column]` string + sql_path: a `[table].[column]` string. + quoted: Whether or not to force quote identifiers. + dialect: the source dialect according to which the column name will be parsed. + copy: Whether to copy a column if it is passed in. + kwargs: the kwargs to instantiate the resulting `Column` expression with. + Returns: - Table: A column expression + A column expression. """ - if sql_path is None or isinstance(sql_path, Column): - return sql_path - if not isinstance(sql_path, str): - raise ValueError(f"Invalid type provided for column: {type(sql_path)}") - return column(*reversed(sql_path.split(".")), **kwargs) # type: ignore + if isinstance(sql_path, Column): + return maybe_copy(sql_path, copy=copy) + + try: + col = maybe_parse(sql_path, into=Column, dialect=dialect) + except ParseError: + return column(*reversed(sql_path.split(".")), quoted=quoted, **kwargs) + + for k, v in kwargs.items(): + col.set(k, v) + + if quoted: + for i in col.find_all(Identifier): + i.set("quoted", True) + + return col def alias_( @@ -6756,7 +6799,7 @@ def subquery( A new Select instance with the subquery expression included. """ - expression = maybe_parse(expression, dialect=dialect, **opts).subquery(alias) + expression = maybe_parse(expression, dialect=dialect, **opts).subquery(alias, **opts) return Select().from_(expression, dialect=dialect, **opts) @@ -6821,7 +6864,9 @@ def column( ) if fields: - this = Dot.build((this, *(to_identifier(field, copy=copy) for field in fields))) + this = Dot.build( + (this, *(to_identifier(field, quoted=quoted, copy=copy) for field in fields)) + ) return this @@ -6840,11 +6885,16 @@ def cast(expression: ExpOrStr, to: DATA_TYPE, copy: bool = True, **opts) -> Cast Returns: The new Cast instance. """ - expression = maybe_parse(expression, copy=copy, **opts) + expr = maybe_parse(expression, copy=copy, **opts) data_type = DataType.build(to, copy=copy, **opts) - expression = Cast(this=expression, to=data_type) - expression.type = data_type - return expression + + if expr.is_type(data_type): + return expr + + expr = Cast(this=expr, to=data_type) + expr.type = data_type + + return expr def table_( @@ -6931,18 +6981,23 @@ def var(name: t.Optional[ExpOrStr]) -> Var: return Var(this=name) -def rename_table(old_name: str | Table, new_name: str | Table) -> AlterTable: +def rename_table( + old_name: str | Table, + new_name: str | Table, + dialect: DialectType = None, +) -> AlterTable: """Build ALTER TABLE... RENAME... expression Args: old_name: The old name of the table new_name: The new name of the table + dialect: The dialect to parse the table. Returns: Alter table expression """ - old_table = to_table(old_name) - new_table = to_table(new_name) + old_table = to_table(old_name, dialect=dialect) + new_table = to_table(new_name, dialect=dialect) return AlterTable( this=old_table, actions=[ @@ -6956,6 +7011,7 @@ def rename_column( old_column_name: str | Column, new_column_name: str | Column, exists: t.Optional[bool] = None, + dialect: DialectType = None, ) -> AlterTable: """Build ALTER TABLE... RENAME COLUMN... expression @@ -6964,13 +7020,14 @@ def rename_column( old_column: The old name of the column new_column: The new name of the column exists: Whether to add the `IF EXISTS` clause + dialect: The dialect to parse the table/column. Returns: Alter table expression """ - table = to_table(table_name) - old_column = to_column(old_column_name) - new_column = to_column(new_column_name) + table = to_table(table_name, dialect=dialect) + old_column = to_column(old_column_name, dialect=dialect) + new_column = to_column(new_column_name, dialect=dialect) return AlterTable( this=table, actions=[ @@ -7366,27 +7423,6 @@ def case( return Case(this=this, ifs=[]) -def cast_unless( - expression: ExpOrStr, - to: DATA_TYPE, - *types: DATA_TYPE, - **opts: t.Any, -) -> Expression | Cast: - """ - Cast an expression to a data type unless it is a specified type. - - Args: - expression: The expression to cast. - to: The data type to cast to. - **types: The types to exclude from casting. - **opts: Extra keyword arguments for parsing `expression` - """ - expr = maybe_parse(expression, **opts) - if expr.is_type(*types): - return expr - return cast(expr, to, **opts) - - def array( *expressions: ExpOrStr, copy: bool = True, dialect: DialectType = None, **kwargs ) -> Array: diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 76d9b5d..b7da18b 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -89,6 +89,8 @@ class Generator(metaclass=_Generator): exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}", exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}", exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}", + exp.EphemeralColumnConstraint: lambda self, + e: f"EPHEMERAL{(' ' + self.sql(e, 'this')) if e.this else ''}", exp.ExcludeColumnConstraint: lambda self, e: f"EXCLUDE {self.sql(e, 'this').lstrip()}", exp.ExecuteAsProperty: lambda self, e: self.naked_property(e), exp.ExternalProperty: lambda *_: "EXTERNAL", @@ -332,6 +334,11 @@ class Generator(metaclass=_Generator): # Whether the function TO_NUMBER is supported SUPPORTS_TO_NUMBER = True + # Whether or not union modifiers apply to the outer union or select. + # SELECT * FROM x UNION SELECT * FROM y LIMIT 1 + # True means limit 1 happens after the union, False means it it happens on y. + OUTER_UNION_MODIFIERS = True + TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", exp.DataType.Type.NVARCHAR: "VARCHAR", @@ -1801,10 +1808,15 @@ class Generator(metaclass=_Generator): return f"{self.seg('FROM')} {self.sql(expression, 'this')}" def group_sql(self, expression: exp.Group) -> str: - group_by = self.op_expressions("GROUP BY", expression) + group_by_all = expression.args.get("all") + if group_by_all is True: + modifier = " ALL" + elif group_by_all is False: + modifier = " DISTINCT" + else: + modifier = "" - if expression.args.get("all"): - return f"{group_by} ALL" + group_by = self.op_expressions(f"GROUP BY{modifier}", expression) grouping_sets = self.expressions(expression, key="grouping_sets", indent=False) grouping_sets = ( @@ -2109,6 +2121,14 @@ class Generator(metaclass=_Generator): return f"{this}{sort_order}{nulls_sort_change}{with_fill}" + def matchrecognizemeasure_sql(self, expression: exp.MatchRecognizeMeasure) -> str: + window_frame = self.sql(expression, "window_frame") + window_frame = f"{window_frame} " if window_frame else "" + + this = self.sql(expression, "this") + + return f"{window_frame}{this}" + def matchrecognize_sql(self, expression: exp.MatchRecognize) -> str: partition = self.partition_by_sql(expression) order = self.sql(expression, "order") @@ -2297,6 +2317,19 @@ class Generator(metaclass=_Generator): return f"{self.seg('QUALIFY')}{self.sep()}{this}" def set_operations(self, expression: exp.Union) -> str: + if not self.OUTER_UNION_MODIFIERS: + limit = expression.args.get("limit") + order = expression.args.get("order") + + if limit or order: + select = exp.subquery(expression, "_l_0", copy=False).select("*", copy=False) + + if limit: + select = select.limit(limit.pop(), copy=False) + if order: + select = select.order_by(order.pop(), copy=False) + return self.sql(select) + sqls: t.List[str] = [] stack: t.List[t.Union[str, exp.Expression]] = [expression] @@ -2412,12 +2445,15 @@ class Generator(metaclass=_Generator): high = self.sql(expression, "high") return f"{this} BETWEEN {low} AND {high}" - def bracket_sql(self, expression: exp.Bracket) -> str: - expressions = apply_index_offset( + def bracket_offset_expressions(self, expression: exp.Bracket) -> t.List[exp.Expression]: + return apply_index_offset( expression.this, expression.expressions, self.dialect.INDEX_OFFSET - expression.args.get("offset", 0), ) + + def bracket_sql(self, expression: exp.Bracket) -> str: + expressions = self.bracket_offset_expressions(expression) expressions_sql = ", ".join(self.sql(e) for e in expressions) return f"{self.sql(expression, 'this')}[{expressions_sql}]" @@ -2486,7 +2522,7 @@ class Generator(metaclass=_Generator): args = args[1:] # Skip the delimiter if self.dialect.STRICT_STRING_CONCAT and expression.args.get("safe"): - args = [exp.cast(e, "text") for e in args] + args = [exp.cast(e, exp.DataType.Type.TEXT) for e in args] if not self.dialect.CONCAT_COALESCE and expression.args.get("coalesce"): args = [exp.func("coalesce", e, exp.Literal.string("")) for e in args] @@ -2670,7 +2706,7 @@ class Generator(metaclass=_Generator): is_global = " GLOBAL" if expression.args.get("is_global") else "" if query: - in_sql = self.wrap(self.sql(query)) + in_sql = self.sql(query) elif unnest: in_sql = self.in_unnest_op(unnest) elif field: @@ -2859,9 +2895,10 @@ class Generator(metaclass=_Generator): def comment_sql(self, expression: exp.Comment) -> str: this = self.sql(expression, "this") kind = expression.args["kind"] + materialized = " MATERIALIZED" if expression.args.get("materialized") else "" exists_sql = " IF EXISTS " if expression.args.get("exists") else " " expression_sql = self.sql(expression, "expression") - return f"COMMENT{exists_sql}ON {kind} {this} IS {expression_sql}" + return f"COMMENT{exists_sql}ON{materialized} {kind} {this} IS {expression_sql}" def mergetreettlaction_sql(self, expression: exp.MergeTreeTTLAction) -> str: this = self.sql(expression, "this") @@ -3011,7 +3048,9 @@ class Generator(metaclass=_Generator): def dpipe_sql(self, expression: exp.DPipe) -> str: if self.dialect.STRICT_STRING_CONCAT and expression.args.get("safe"): - return self.func("CONCAT", *(exp.cast(e, "text") for e in expression.flatten())) + return self.func( + "CONCAT", *(exp.cast(e, exp.DataType.Type.TEXT) for e in expression.flatten()) + ) return self.binary(expression, "||") def div_sql(self, expression: exp.Div) -> str: @@ -3210,11 +3249,8 @@ class Generator(metaclass=_Generator): num_sqls = len(expressions) # These are calculated once in case we have the leading_comma / pretty option set, correspondingly - if self.pretty: - if self.leading_comma: - pad = " " * len(sep) - else: - stripped_sep = sep.strip() + if self.pretty and not self.leading_comma: + stripped_sep = sep.strip() result_sqls = [] for i, e in enumerate(expressions): @@ -3226,7 +3262,7 @@ class Generator(metaclass=_Generator): if self.pretty: if self.leading_comma: - result_sqls.append(f"{sep if i > 0 else pad}{prefix}{sql}{comments}") + result_sqls.append(f"{sep if i > 0 else ''}{prefix}{sql}{comments}") else: result_sqls.append( f"{prefix}{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}" @@ -3314,17 +3350,17 @@ class Generator(metaclass=_Generator): if expression.args.get("format"): self.unsupported("Format argument unsupported for TO_CHAR/TO_VARCHAR function") - return self.sql(exp.cast(expression.this, "text")) + return self.sql(exp.cast(expression.this, exp.DataType.Type.TEXT)) def tonumber_sql(self, expression: exp.ToNumber) -> str: if not self.SUPPORTS_TO_NUMBER: self.unsupported("Unsupported TO_NUMBER function") - return self.sql(exp.cast(expression.this, "double")) + return self.sql(exp.cast(expression.this, exp.DataType.Type.DOUBLE)) fmt = expression.args.get("format") if not fmt: self.unsupported("Conversion format is required for TO_NUMBER") - return self.sql(exp.cast(expression.this, "double")) + return self.sql(exp.cast(expression.this, exp.DataType.Type.DOUBLE)) return self.func("TO_NUMBER", expression.this, fmt) @@ -3495,14 +3531,14 @@ class Generator(metaclass=_Generator): if isinstance(this, exp.TsOrDsToTime) or this.is_type(exp.DataType.Type.TIME): return self.sql(this) - return self.sql(exp.cast(this, "time")) + return self.sql(exp.cast(this, exp.DataType.Type.TIME)) def tsordstotimestamp_sql(self, expression: exp.TsOrDsToTimestamp) -> str: this = expression.this if isinstance(this, exp.TsOrDsToTimestamp) or this.is_type(exp.DataType.Type.TIMESTAMP): return self.sql(this) - return self.sql(exp.cast(this, "timestamp")) + return self.sql(exp.cast(this, exp.DataType.Type.TIMESTAMP)) def tsordstodate_sql(self, expression: exp.TsOrDsToDate) -> str: this = expression.this @@ -3510,20 +3546,23 @@ class Generator(metaclass=_Generator): if time_format and time_format not in (self.dialect.TIME_FORMAT, self.dialect.DATE_FORMAT): return self.sql( - exp.cast(exp.StrToTime(this=this, format=expression.args["format"]), "date") + exp.cast( + exp.StrToTime(this=this, format=expression.args["format"]), + exp.DataType.Type.DATE, + ) ) if isinstance(this, exp.TsOrDsToDate) or this.is_type(exp.DataType.Type.DATE): return self.sql(this) - return self.sql(exp.cast(this, "date")) + return self.sql(exp.cast(this, exp.DataType.Type.DATE)) def unixdate_sql(self, expression: exp.UnixDate) -> str: return self.sql( exp.func( "DATEDIFF", expression.this, - exp.cast(exp.Literal.string("1970-01-01"), "date"), + exp.cast(exp.Literal.string("1970-01-01"), exp.DataType.Type.DATE), "day", ) ) diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index c85ef1c..8635f13 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -212,6 +212,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.Month, exp.Week, exp.Year, + exp.Quarter, }, exp.DataType.Type.VARCHAR: { exp.ArrayConcat, @@ -504,7 +505,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator): last_datatype = expr_type break - last_datatype = self._maybe_coerce(last_datatype or expr_type, expr_type) + if not expr_type.is_type(exp.DataType.Type.NULL, exp.DataType.Type.UNKNOWN): + last_datatype = self._maybe_coerce(last_datatype or expr_type, expr_type) self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN) diff --git a/sqlglot/optimizer/qualify.py b/sqlglot/optimizer/qualify.py index e4f8b57..4f64497 100644 --- a/sqlglot/optimizer/qualify.py +++ b/sqlglot/optimizer/qualify.py @@ -66,7 +66,7 @@ def qualify( """ schema = ensure_schema(schema, dialect=dialect) expression = normalize_identifiers(expression, dialect=dialect) - expression = qualify_tables(expression, db=db, catalog=catalog, schema=schema) + expression = qualify_tables(expression, db=db, catalog=catalog, schema=schema, dialect=dialect) if isolate_tables: expression = isolate_table_selects(expression, schema=schema) diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 027c32c..bd875a4 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -7,6 +7,7 @@ from sqlglot import alias, exp from sqlglot.dialects.dialect import Dialect, DialectType from sqlglot.errors import OptimizeError from sqlglot.helper import seq_get, SingleValuedMapping +from sqlglot.optimizer.annotate_types import annotate_types from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope from sqlglot.optimizer.simplify import simplify_parens from sqlglot.schema import Schema, ensure_schema @@ -652,8 +653,19 @@ class Resolver: if isinstance(source, exp.Table): columns = self.schema.column_names(source, only_visible) - elif isinstance(source, Scope) and isinstance(source.expression, exp.Values): - columns = source.expression.alias_column_names + elif isinstance(source, Scope) and isinstance(source.expression, (exp.Values, exp.Unnest)): + columns = source.expression.named_selects + + # in bigquery, unnest structs are automatically scoped as tables, so you can + # directly select a struct field in a query. + # this handles the case where the unnest is statically defined. + if self.schema.dialect == "bigquery": + expression = source.expression + annotate_types(expression) + + if expression.is_type(exp.DataType.Type.STRUCT): + for k in expression.type.expressions: # type: ignore + columns.append(k.name) else: columns = source.expression.named_selects diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index a034bf5..de18ca5 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -55,7 +55,7 @@ def qualify_tables( if not table.args.get("catalog") and table.args.get("db"): table.set("catalog", catalog) - if not isinstance(expression, exp.Query): + if (db or catalog) and not isinstance(expression, exp.Query): for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)): if isinstance(node, exp.Table): _qualify(node) @@ -78,10 +78,10 @@ def qualify_tables( if pivots and not pivots[0].alias: pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))) + table_aliases = {} + for name, source in scope.sources.items(): if isinstance(source, exp.Table): - _qualify(source) - pivots = pivots = source.args.get("pivots") if not source.alias: # Don't add the pivot's alias to the pivoted table, use the table's name instead @@ -91,6 +91,12 @@ def qualify_tables( # Mutates the source by attaching an alias to it alias(source, name or source.name or next_alias_name(), copy=False, table=True) + table_aliases[".".join(p.name for p in source.parts)] = exp.to_identifier( + source.alias + ) + + _qualify(source) + if pivots and not pivots[0].alias: pivots[0].set( "alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())) @@ -127,4 +133,13 @@ def qualify_tables( # Mutates the table by attaching an alias to it alias(node, node.name, copy=False, table=True) + for column in scope.columns: + if column.db: + table_alias = table_aliases.get(".".join(p.name for p in column.parts[0:-1])) + + if table_alias: + for p in exp.COLUMN_PARTS[1:]: + column.set(p, None) + column.set("table", table_alias) + return expression diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 073ced2..c589e24 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -600,7 +600,7 @@ def _traverse_ctes(scope): sources = {} for cte in scope.ctes: - recursive_scope = None + cte_name = cte.alias # if the scope is a recursive cte, it must be in the form of base_case UNION recursive. # thus the recursive scope is the first section of the union. @@ -609,7 +609,7 @@ def _traverse_ctes(scope): union = cte.this if isinstance(union, exp.Union): - recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE) + sources[cte_name] = scope.branch(union.this, scope_type=ScopeType.CTE) child_scope = None @@ -623,15 +623,9 @@ def _traverse_ctes(scope): ): yield child_scope - alias = cte.alias - sources[alias] = child_scope - - if recursive_scope: - child_scope.add_source(alias, recursive_scope) - child_scope.cte_sources[alias] = recursive_scope - # append the final child_scope yielded if child_scope: + sources[cte_name] = child_scope scope.cte_scopes.append(child_scope) scope.sources.update(sources) diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index b83abe6..be3ab6f 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -41,8 +41,6 @@ def unnest(select, parent_select, next_alias_name): return predicate = select.find_ancestor(exp.Condition) - alias = next_alias_name() - if ( not predicate or parent_select is not predicate.parent_select @@ -50,6 +48,10 @@ def unnest(select, parent_select, next_alias_name): ): return + if isinstance(select, exp.Union): + select = exp.select(*select.selects).from_(select.subquery(next_alias_name())) + + alias = next_alias_name() clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join) # This subquery returns a scalar and can just be converted to a cross join diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 91d8d13..99717f4 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -344,6 +344,7 @@ class Parser(metaclass=_Parser): TokenType.FINAL, TokenType.FORMAT, TokenType.FULL, + TokenType.IDENTIFIER, TokenType.IS, TokenType.ISNULL, TokenType.INTERVAL, @@ -852,6 +853,9 @@ class Parser(metaclass=_Parser): exp.DefaultColumnConstraint, this=self._parse_bitwise() ), "ENCODE": lambda self: self.expression(exp.EncodeColumnConstraint, this=self._parse_var()), + "EPHEMERAL": lambda self: self.expression( + exp.EphemeralColumnConstraint, this=self._parse_bitwise() + ), "EXCLUDE": lambda self: self.expression( exp.ExcludeColumnConstraint, this=self._parse_index_params() ), @@ -1384,6 +1388,7 @@ class Parser(metaclass=_Parser): self._match(TokenType.ON) + materialized = self._match_text_seq("MATERIALIZED") kind = self._match_set(self.CREATABLES) and self._prev if not kind: return self._parse_as_command(start) @@ -1400,7 +1405,12 @@ class Parser(metaclass=_Parser): self._match(TokenType.IS) return self.expression( - exp.Comment, this=this, kind=kind.text, expression=self._parse_string(), exists=exists + exp.Comment, + this=this, + kind=kind.text, + expression=self._parse_string(), + exists=exists, + materialized=materialized, ) def _parse_to_table( @@ -2188,7 +2198,10 @@ class Parser(metaclass=_Parser): def _parse_describe(self) -> exp.Describe: kind = self._match_set(self.CREATABLES) and self._prev.text - style = self._match_texts(("EXTENDED", "FORMATTED")) and self._prev.text.upper() + style = self._match_texts(("EXTENDED", "FORMATTED", "HISTORY")) and self._prev.text.upper() + if not self._match_set(self.ID_VAR_TOKENS, advance=False): + style = None + self._retreat(self._index - 1) this = self._parse_table(schema=True) properties = self._parse_properties() expressions = properties.expressions if properties else None @@ -2731,6 +2744,13 @@ class Parser(metaclass=_Parser): exp.From, comments=self._prev_comments, this=self._parse_table(joins=joins) ) + def _parse_match_recognize_measure(self) -> exp.MatchRecognizeMeasure: + return self.expression( + exp.MatchRecognizeMeasure, + window_frame=self._match_texts(("FINAL", "RUNNING")) and self._prev.text.upper(), + this=self._parse_expression(), + ) + def _parse_match_recognize(self) -> t.Optional[exp.MatchRecognize]: if not self._match(TokenType.MATCH_RECOGNIZE): return None @@ -2739,7 +2759,12 @@ class Parser(metaclass=_Parser): partition = self._parse_partition_by() order = self._parse_order() - measures = self._parse_expressions() if self._match_text_seq("MEASURES") else None + + measures = ( + self._parse_csv(self._parse_match_recognize_measure) + if self._match_text_seq("MEASURES") + else None + ) if self._match_text_seq("ONE", "ROW", "PER", "MATCH"): rows = exp.var("ONE ROW PER MATCH") @@ -3444,10 +3469,12 @@ class Parser(metaclass=_Parser): if not skip_group_by_token and not self._match(TokenType.GROUP_BY): return None - elements = defaultdict(list) + elements: t.Dict[str, t.Any] = defaultdict(list) if self._match(TokenType.ALL): - return self.expression(exp.Group, all=True) + elements["all"] = True + elif self._match(TokenType.DISTINCT): + elements["all"] = False while True: expressions = self._parse_csv(self._parse_conjunction) @@ -3808,7 +3835,7 @@ class Parser(metaclass=_Parser): expressions = self._parse_csv(lambda: self._parse_select_or_expression(alias=alias)) if len(expressions) == 1 and isinstance(expressions[0], exp.Query): - this = self.expression(exp.In, this=this, query=expressions[0]) + this = self.expression(exp.In, this=this, query=expressions[0].subquery(copy=False)) else: this = self.expression(exp.In, this=this, expressions=expressions) @@ -4504,12 +4531,15 @@ class Parser(metaclass=_Parser): constraints: t.List[exp.Expression] = [] - if (not kind and self._match(TokenType.ALIAS)) or self._match_text_seq("ALIAS"): + if (not kind and self._match(TokenType.ALIAS)) or self._match_texts( + ("ALIAS", "MATERIALIZED") + ): + persisted = self._prev.text.upper() == "MATERIALIZED" constraints.append( self.expression( exp.ComputedColumnConstraint, this=self._parse_conjunction(), - persisted=self._match_text_seq("PERSISTED"), + persisted=persisted or self._match_text_seq("PERSISTED"), not_null=self._match_pair(TokenType.NOT, TokenType.NULL), ) ) diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index f44c18c..2bd02ec 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -140,6 +140,26 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expr return expression +def unqualify_unnest(expression: exp.Expression) -> exp.Expression: + """Remove references to unnest table aliases, added by the optimizer's qualify_columns step.""" + from sqlglot.optimizer.scope import find_all_in_scope + + if isinstance(expression, exp.Select): + unnest_aliases = { + unnest.alias + for unnest in find_all_in_scope(expression, exp.Unnest) + if isinstance(unnest.parent, (exp.From, exp.Join)) + } + if unnest_aliases: + for column in expression.find_all(exp.Column): + if column.table in unnest_aliases: + column.set("table", None) + elif column.db in unnest_aliases: + column.set("db", None) + + return expression + + def unnest_to_explode(expression: exp.Expression) -> exp.Expression: """Convert cross join unnest into lateral view explode.""" if isinstance(expression, exp.Select): |