Diffstat (limited to 'sqlglot/dialects/bigquery.py')
-rw-r--r-- | sqlglot/dialects/bigquery.py | 145
1 file changed, 56 insertions(+), 89 deletions(-)
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index c0191b2..f867617 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -12,13 +12,14 @@ from sqlglot.dialects.dialect import (
     binary_from_function,
     date_add_interval_sql,
     datestrtodate_sql,
-    format_time_lambda,
+    build_formatted_time,
+    filter_array_using_unnest,
     if_sql,
     inline_array_sql,
     max_or_greatest,
     min_or_least,
     no_ilike_sql,
-    parse_date_delta_with_interval,
+    build_date_delta_with_interval,
     regexp_replace_sql,
     rename_func,
     timestrtotime_sql,
@@ -37,56 +38,33 @@ def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Va
     if not expression.find_ancestor(exp.From, exp.Join):
         return self.values_sql(expression)

+    structs = []
     alias = expression.args.get("alias")
+    for tup in expression.find_all(exp.Tuple):
+        field_aliases = alias.columns if alias else (f"_c{i}" for i in range(len(tup.expressions)))
+        expressions = [exp.alias_(fld, name) for fld, name in zip(tup.expressions, field_aliases)]
+        structs.append(exp.Struct(expressions=expressions))

-    return self.unnest_sql(
-        exp.Unnest(
-            expressions=[
-                exp.array(
-                    *(
-                        exp.Struct(
-                            expressions=[
-                                exp.alias_(value, column_name)
-                                for value, column_name in zip(
-                                    t.expressions,
-                                    (
-                                        alias.columns
-                                        if alias and alias.columns
-                                        else (f"_c{i}" for i in range(len(t.expressions)))
-                                    ),
-                                )
-                            ]
-                        )
-                        for t in expression.find_all(exp.Tuple)
-                    ),
-                    copy=False,
-                )
-            ]
-        )
-    )
+    return self.unnest_sql(exp.Unnest(expressions=[exp.array(*structs, copy=False)]))


 def _returnsproperty_sql(self: BigQuery.Generator, expression: exp.ReturnsProperty) -> str:
     this = expression.this
     if isinstance(this, exp.Schema):
-        this = f"{this.this} <{self.expressions(this)}>"
+        this = f"{self.sql(this, 'this')} <{self.expressions(this)}>"
     else:
         this = self.sql(this)
     return f"RETURNS {this}"


 def _create_sql(self: BigQuery.Generator, expression: exp.Create) -> str:
-    kind = expression.args["kind"]
     returns = expression.find(exp.ReturnsProperty)
-
-    if kind.upper() == "FUNCTION" and returns and returns.args.get("is_table"):
+    if expression.kind == "FUNCTION" and returns and returns.args.get("is_table"):
         expression.set("kind", "TABLE FUNCTION")

         if isinstance(expression.expression, (exp.Subquery, exp.Literal)):
             expression.set("expression", expression.expression.this)

-        return self.create_sql(expression)
-
     return self.create_sql(expression)


@@ -132,11 +110,10 @@ def _alias_ordered_group(expression: exp.Expression) -> exp.Expression:
         if isinstance(select, exp.Alias)
     }

-    for e in group.expressions:
-        alias = aliases.get(e)
-
+    for grouped in group.expressions:
+        alias = aliases.get(grouped)
         if alias:
-            e.replace(exp.column(alias))
+            grouped.replace(exp.column(alias))

     return expression

@@ -168,24 +145,24 @@ def _pushdown_cte_column_names(expression: exp.Expression) -> exp.Expression:
     return expression


-def _parse_parse_timestamp(args: t.List) -> exp.StrToTime:
-    this = format_time_lambda(exp.StrToTime, "bigquery")([seq_get(args, 1), seq_get(args, 0)])
+def _build_parse_timestamp(args: t.List) -> exp.StrToTime:
+    this = build_formatted_time(exp.StrToTime, "bigquery")([seq_get(args, 1), seq_get(args, 0)])
     this.set("zone", seq_get(args, 2))
     return this


-def _parse_timestamp(args: t.List) -> exp.Timestamp:
+def _build_timestamp(args: t.List) -> exp.Timestamp:
     timestamp = exp.Timestamp.from_arg_list(args)
     timestamp.set("with_tz", True)
     return timestamp


-def _parse_date(args: t.List) -> exp.Date | exp.DateFromParts:
+def _build_date(args: t.List) -> exp.Date | exp.DateFromParts:
     expr_type = exp.DateFromParts if len(args) == 3 else exp.Date
     return expr_type.from_arg_list(args)


-def _parse_to_hex(args: t.List) -> exp.Hex | exp.MD5:
+def _build_to_hex(args: t.List) -> exp.Hex | exp.MD5:
     # TO_HEX(MD5(..)) is common in BigQuery, so it's parsed into MD5 to simplify its transpilation
     arg = seq_get(args, 0)
     return exp.MD5(this=arg.this) if isinstance(arg, exp.MD5Digest) else exp.Hex(this=arg)
@@ -214,18 +191,20 @@ def _ts_or_ds_diff_sql(self: BigQuery.Generator, expression: exp.TsOrDsDiff) ->

 def _unix_to_time_sql(self: BigQuery.Generator, expression: exp.UnixToTime) -> str:
     scale = expression.args.get("scale")
-    timestamp = self.sql(expression, "this")
+    timestamp = expression.this
+
     if scale in (None, exp.UnixToTime.SECONDS):
-        return f"TIMESTAMP_SECONDS({timestamp})"
+        return self.func("TIMESTAMP_SECONDS", timestamp)
     if scale == exp.UnixToTime.MILLIS:
-        return f"TIMESTAMP_MILLIS({timestamp})"
+        return self.func("TIMESTAMP_MILLIS", timestamp)
     if scale == exp.UnixToTime.MICROS:
-        return f"TIMESTAMP_MICROS({timestamp})"
+        return self.func("TIMESTAMP_MICROS", timestamp)

-    return f"TIMESTAMP_SECONDS(CAST({timestamp} / POW(10, {scale}) AS INT64))"
+    unix_seconds = exp.cast(exp.Div(this=timestamp, expression=exp.func("POW", 10, scale)), "int64")
+    return self.func("TIMESTAMP_SECONDS", unix_seconds)


-def _parse_time(args: t.List) -> exp.Func:
+def _build_time(args: t.List) -> exp.Func:
     if len(args) == 1:
         return exp.TsOrDsToTime(this=args[0])
     if len(args) == 3:
@@ -323,6 +302,7 @@ class BigQuery(Dialect):
             "BYTES": TokenType.BINARY,
             "CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
             "DECLARE": TokenType.COMMAND,
+            "EXCEPTION": TokenType.COMMAND,
             "FLOAT64": TokenType.DOUBLE,
             "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
             "MODEL": TokenType.MODEL,
@@ -340,15 +320,15 @@ class BigQuery(Dialect):

         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
-            "DATE": _parse_date,
-            "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
-            "DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
+            "DATE": _build_date,
+            "DATE_ADD": build_date_delta_with_interval(exp.DateAdd),
+            "DATE_SUB": build_date_delta_with_interval(exp.DateSub),
             "DATE_TRUNC": lambda args: exp.DateTrunc(
                 unit=exp.Literal.string(str(seq_get(args, 1))),
                 this=seq_get(args, 0),
             ),
-            "DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd),
-            "DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub),
+            "DATETIME_ADD": build_date_delta_with_interval(exp.DatetimeAdd),
+            "DATETIME_SUB": build_date_delta_with_interval(exp.DatetimeSub),
             "DIV": binary_from_function(exp.IntDiv),
             "FORMAT_DATE": lambda args: exp.TimeToStr(
                 this=exp.TsOrDsToDate(this=seq_get(args, 1)), format=seq_get(args, 0)
@@ -358,11 +338,11 @@ class BigQuery(Dialect):
                 this=seq_get(args, 0), expression=seq_get(args, 1) or exp.Literal.string("$")
             ),
             "MD5": exp.MD5Digest.from_arg_list,
-            "TO_HEX": _parse_to_hex,
-            "PARSE_DATE": lambda args: format_time_lambda(exp.StrToDate, "bigquery")(
+            "TO_HEX": _build_to_hex,
+            "PARSE_DATE": lambda args: build_formatted_time(exp.StrToDate, "bigquery")(
                 [seq_get(args, 1), seq_get(args, 0)]
             ),
-            "PARSE_TIMESTAMP": _parse_parse_timestamp,
+            "PARSE_TIMESTAMP": _build_parse_timestamp,
             "REGEXP_CONTAINS": exp.RegexpLike.from_arg_list,
             "REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
                 this=seq_get(args, 0),
@@ -378,12 +358,12 @@ class BigQuery(Dialect):
                 this=seq_get(args, 0),
                 expression=seq_get(args, 1) or exp.Literal.string(","),
             ),
-            "TIME": _parse_time,
-            "TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd),
-            "TIME_SUB": parse_date_delta_with_interval(exp.TimeSub),
-            "TIMESTAMP": _parse_timestamp,
-            "TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd),
-            "TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub),
+            "TIME": _build_time,
+            "TIME_ADD": build_date_delta_with_interval(exp.TimeAdd),
+            "TIME_SUB": build_date_delta_with_interval(exp.TimeSub),
+            "TIMESTAMP": _build_timestamp,
+            "TIMESTAMP_ADD": build_date_delta_with_interval(exp.TimestampAdd),
+            "TIMESTAMP_SUB": build_date_delta_with_interval(exp.TimestampSub),
             "TIMESTAMP_MICROS": lambda args: exp.UnixToTime(
                 this=seq_get(args, 0), scale=exp.UnixToTime.MICROS
             ),
@@ -424,7 +404,7 @@ class BigQuery(Dialect):
         }

         RANGE_PARSERS = parser.Parser.RANGE_PARSERS.copy()
-        RANGE_PARSERS.pop(TokenType.OVERLAPS, None)
+        RANGE_PARSERS.pop(TokenType.OVERLAPS)

         NULL_TOKENS = {TokenType.NULL, TokenType.UNKNOWN}

@@ -551,6 +531,7 @@ class BigQuery(Dialect):
         NULL_ORDERING_SUPPORTED = False
         IGNORE_NULLS_IN_FUNC = True
         JSON_PATH_SINGLE_QUOTE_ESCAPE = True
+        CAN_IMPLEMENT_ARRAY_ANY = True

         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
@@ -558,6 +539,7 @@ class BigQuery(Dialect):
             exp.ArgMax: arg_max_or_min_no_count("MAX_BY"),
             exp.ArgMin: arg_max_or_min_no_count("MIN_BY"),
             exp.ArrayContains: _array_contains_sql,
+            exp.ArrayFilter: filter_array_using_unnest,
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
             exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
             exp.CollateProperty: lambda self, e: (
@@ -565,12 +547,14 @@ class BigQuery(Dialect):
                 if e.args.get("default")
                 else f"COLLATE {self.sql(e, 'this')}"
             ),
+            exp.Commit: lambda *_: "COMMIT TRANSACTION",
             exp.CountIf: rename_func("COUNTIF"),
             exp.Create: _create_sql,
             exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
             exp.DateAdd: date_add_interval_sql("DATE", "ADD"),
-            exp.DateDiff: lambda self,
-            e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
+            exp.DateDiff: lambda self, e: self.func(
+                "DATE_DIFF", e.this, e.expression, e.unit or "DAY"
+            ),
             exp.DateFromParts: rename_func("DATE"),
             exp.DateStrToDate: datestrtodate_sql,
             exp.DateSub: date_add_interval_sql("DATE", "SUB"),
@@ -602,6 +586,7 @@ class BigQuery(Dialect):
             exp.RegexpReplace: regexp_replace_sql,
             exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
             exp.ReturnsProperty: _returnsproperty_sql,
+            exp.Rollback: lambda *_: "ROLLBACK TRANSACTION",
             exp.Select: transforms.preprocess(
                 [
                     transforms.explode_to_unnest(),
@@ -617,8 +602,7 @@ class BigQuery(Dialect):
             exp.StabilityProperty: lambda self, e: (
                 "DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC"
             ),
-            exp.StrToDate: lambda self,
-            e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
+            exp.StrToDate: lambda self, e: self.func("PARSE_DATE", self.format_time(e), e.this),
             exp.StrToTime: lambda self, e: self.func(
                 "PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone")
             ),
@@ -629,6 +613,7 @@ class BigQuery(Dialect):
             exp.TimestampDiff: rename_func("TIMESTAMP_DIFF"),
             exp.TimestampSub: date_add_interval_sql("TIMESTAMP", "SUB"),
             exp.TimeStrToTime: timestrtotime_sql,
+            exp.Transaction: lambda *_: "BEGIN TRANSACTION",
             exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
             exp.TsOrDsAdd: _ts_or_ds_add_sql,
             exp.TsOrDsDiff: _ts_or_ds_diff_sql,
@@ -778,12 +763,8 @@ class BigQuery(Dialect):
         }

         def timetostr_sql(self, expression: exp.TimeToStr) -> str:
-            if isinstance(expression.this, exp.TsOrDsToDate):
-                this: exp.Expression = expression.this
-            else:
-                this = expression
-
-            return f"FORMAT_DATE({self.format_time(expression)}, {self.sql(this, 'this')})"
+            this = expression.this if isinstance(expression.this, exp.TsOrDsToDate) else expression
+            return self.func("FORMAT_DATE", self.format_time(expression), this.this)

         def struct_sql(self, expression: exp.Struct) -> str:
             args = []
@@ -820,11 +801,6 @@ class BigQuery(Dialect):
         def trycast_sql(self, expression: exp.TryCast) -> str:
             return self.cast_sql(expression, safe_prefix="SAFE_")

-        def cte_sql(self, expression: exp.CTE) -> str:
-            if expression.alias_column_names:
-                self.unsupported("Column names in CTE definition are not supported.")
-            return super().cte_sql(expression)
-
         def array_sql(self, expression: exp.Array) -> str:
             first_arg = seq_get(expression.expressions, 0)
             if isinstance(first_arg, exp.Subqueryable):
@@ -862,25 +838,16 @@ class BigQuery(Dialect):

             return f"{this}[{expressions_sql}]"

-        def transaction_sql(self, *_) -> str:
-            return "BEGIN TRANSACTION"
-
-        def commit_sql(self, *_) -> str:
-            return "COMMIT TRANSACTION"
-
-        def rollback_sql(self, *_) -> str:
-            return "ROLLBACK TRANSACTION"
-
         def in_unnest_op(self, expression: exp.Unnest) -> str:
             return self.sql(expression)

         def except_op(self, expression: exp.Except) -> str:
-            if not expression.args.get("distinct", False):
+            if not expression.args.get("distinct"):
                 self.unsupported("EXCEPT without DISTINCT is not supported in BigQuery")
             return f"EXCEPT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"

         def intersect_op(self, expression: exp.Intersect) -> str:
-            if not expression.args.get("distinct", False):
+            if not expression.args.get("distinct"):
                 self.unsupported("INTERSECT without DISTINCT is not supported in BigQuery")
             return f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
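
The renamed parser helpers (parse_* becoming build_*) and the new TRANSFORMS entries are internal to the dialect, but their effect is visible when transpiling through BigQuery. Below is a minimal sketch of how the behavior touched by this diff can be exercised from user code via sqlglot's public transpile API; the expected outputs noted in the comments are approximate and depend on the sqlglot version, and the source dialects chosen for the examples are arbitrary, so treat this as illustration rather than as the project's own tests.

    import sqlglot

    # Transactions: exp.Transaction / exp.Commit / exp.Rollback are now emitted
    # through TRANSFORMS lambdas instead of dedicated generator methods, so these
    # statements should round-trip with the TRANSACTION keyword preserved.
    print(sqlglot.transpile("BEGIN TRANSACTION", read="bigquery", write="bigquery")[0])
    print(sqlglot.transpile("ROLLBACK TRANSACTION", read="bigquery", write="bigquery")[0])

    # VALUES in a FROM clause: _derived_table_values_to_unnest builds the STRUCT
    # list up front and wraps it in UNNEST, yielding roughly
    # SELECT a, b FROM UNNEST([STRUCT(1 AS a, 2 AS b), STRUCT(3 AS a, 4 AS b)]) AS t
    print(
        sqlglot.transpile(
            "SELECT a, b FROM (VALUES (1, 2), (3, 4)) AS t(a, b)",
            read="duckdb",
            write="bigquery",
        )[0]
    )

    # exp.ArrayFilter is now generated with filter_array_using_unnest, so a
    # lambda-style FILTER from another dialect should come out as an
    # ARRAY(SELECT ... FROM UNNEST(...) WHERE ...) construct.
    print(
        sqlglot.transpile(
            "SELECT FILTER(xs, x -> x > 0) FROM tbl",
            read="spark",
            write="bigquery",
        )[0]
    )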