author      Daniel Baumann <daniel.baumann@progress-linux.org>   2023-06-29 13:02:29 +0000
committer   Daniel Baumann <daniel.baumann@progress-linux.org>   2023-06-29 13:02:29 +0000
commit      9b39dac84e82bf473216939e50b8836170f01d23 (patch)
tree        9b405bc86ef7e2ea28cddc6b787ed70355cf7fce /sqlglot/dialects/bigquery.py
parent      Releasing debian version 16.4.2-1. (diff)
download    sqlglot-9b39dac84e82bf473216939e50b8836170f01d23.tar.xz
            sqlglot-9b39dac84e82bf473216939e50b8836170f01d23.zip
Merging upstream version 16.7.3.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/bigquery.py')
-rw-r--r--   sqlglot/dialects/bigquery.py   205
1 file changed, 193 insertions(+), 12 deletions(-)
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 52d4a88..8786063 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import logging
 import re
 import typing as t
 
@@ -21,6 +22,8 @@ from sqlglot.dialects.dialect import (
 from sqlglot.helper import seq_get, split_num_words
 from sqlglot.tokens import TokenType
 
+logger = logging.getLogger("sqlglot")
+
 
 def _date_add_sql(
     data_type: str, kind: str
@@ -104,12 +107,70 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
     return expression
 
 
+# https://issuetracker.google.com/issues/162294746
+# workaround for bigquery bug when grouping by an expression and then ordering
+# WITH x AS (SELECT 1 y)
+# SELECT y + 1 z
+# FROM x
+# GROUP BY x + 1
+# ORDER by z
+def _alias_ordered_group(expression: exp.Expression) -> exp.Expression:
+    if isinstance(expression, exp.Select):
+        group = expression.args.get("group")
+        order = expression.args.get("order")
+
+        if group and order:
+            aliases = {
+                select.this: select.args["alias"]
+                for select in expression.selects
+                if isinstance(select, exp.Alias)
+            }
+
+            for e in group.expressions:
+                alias = aliases.get(e)
+
+                if alias:
+                    e.replace(exp.column(alias))
+
+    return expression
+
+
+def _pushdown_cte_column_names(expression: exp.Expression) -> exp.Expression:
+    """BigQuery doesn't allow column names when defining a CTE, so we try to push them down."""
+    if isinstance(expression, exp.CTE) and expression.alias_column_names:
+        cte_query = expression.this
+
+        if cte_query.is_star:
+            logger.warning(
+                "Can't push down CTE column names for star queries. Run the query through"
+                " the optimizer or use 'qualify' to expand the star projections first."
+            )
+            return expression
+
+        column_names = expression.alias_column_names
+        expression.args["alias"].set("columns", None)
+
+        for name, select in zip(column_names, cte_query.selects):
+            to_replace = select
+
+            if isinstance(select, exp.Alias):
+                select = select.this
+
+            # Inner aliases are shadowed by the CTE column names
+            to_replace.replace(exp.alias_(select, name))
+
+    return expression
+
+
 class BigQuery(Dialect):
     UNNEST_COLUMN_ONLY = True
 
     # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity
     RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
 
+    # bigquery udfs are case sensitive
+    NORMALIZE_FUNCTIONS = False
+
     TIME_MAPPING = {
         "%D": "%m/%d/%y",
     }
@@ -135,12 +196,16 @@ class BigQuery(Dialect):
         # In BigQuery, CTEs aren't case-sensitive, but table names are (by default, at least).
         # The following check is essentially a heuristic to detect tables based on whether or
         # not they're qualified.
-        if (
-            isinstance(expression, exp.Identifier)
-            and not (isinstance(expression.parent, exp.Table) and expression.parent.db)
-            and not expression.meta.get("is_table")
-        ):
-            expression.set("this", expression.this.lower())
+        if isinstance(expression, exp.Identifier):
+            parent = expression.parent
+
+            while isinstance(parent, exp.Dot):
+                parent = parent.parent
+
+            if not (isinstance(parent, exp.Table) and parent.db) and not expression.meta.get(
+                "is_table"
+            ):
+                expression.set("this", expression.this.lower())
 
         return expression
 
@@ -298,10 +363,8 @@ class BigQuery(Dialect):
             **generator.Generator.TRANSFORMS,
             exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
-            exp.AtTimeZone: lambda self, e: self.func(
-                "TIMESTAMP", self.func("DATETIME", e.this, e.args.get("zone"))
-            ),
             exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
+            exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
             exp.DateAdd: _date_add_sql("DATE", "ADD"),
             exp.DateSub: _date_add_sql("DATE", "SUB"),
             exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
@@ -325,7 +388,12 @@ class BigQuery(Dialect):
             ),
             exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
             exp.Select: transforms.preprocess(
-                [_unqualify_unnest, transforms.eliminate_distinct_on]
+                [
+                    transforms.explode_to_unnest,
+                    _unqualify_unnest,
+                    transforms.eliminate_distinct_on,
+                    _alias_ordered_group,
+                ]
             ),
             exp.StrToDate: lambda self, e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
             exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
@@ -334,7 +402,6 @@ class BigQuery(Dialect):
             exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
             exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
             exp.TimeStrToTime: timestrtotime_sql,
-            exp.TryCast: lambda self, e: f"SAFE_CAST({self.sql(e, 'this')} AS {self.sql(e, 'to')})",
             exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"),
             exp.TsOrDsAdd: _date_add_sql("DATE", "ADD"),
             exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
@@ -378,7 +445,121 @@ class BigQuery(Dialect):
             exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
         }
 
-        RESERVED_KEYWORDS = {*generator.Generator.RESERVED_KEYWORDS, "hash"}
+        # from: https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#reserved_keywords
+        RESERVED_KEYWORDS = {
+            *generator.Generator.RESERVED_KEYWORDS,
+            "all",
+            "and",
+            "any",
+            "array",
+            "as",
+            "asc",
+            "assert_rows_modified",
+            "at",
+            "between",
+            "by",
+            "case",
+            "cast",
+            "collate",
+            "contains",
+            "create",
+            "cross",
+            "cube",
+            "current",
+            "default",
+            "define",
+            "desc",
+            "distinct",
+            "else",
+            "end",
+            "enum",
+            "escape",
+            "except",
+            "exclude",
+            "exists",
+            "extract",
+            "false",
+            "fetch",
+            "following",
+            "for",
+            "from",
+            "full",
+            "group",
+            "grouping",
+            "groups",
+            "hash",
+            "having",
+            "if",
+            "ignore",
+            "in",
+            "inner",
+            "intersect",
+            "interval",
+            "into",
+            "is",
+            "join",
+            "lateral",
+            "left",
+            "like",
+            "limit",
+            "lookup",
+            "merge",
+            "natural",
+            "new",
+            "no",
+            "not",
+            "null",
+            "nulls",
+            "of",
+            "on",
+            "or",
+            "order",
+            "outer",
+            "over",
+            "partition",
+            "preceding",
+            "proto",
+            "qualify",
+            "range",
+            "recursive",
+            "respect",
+            "right",
+            "rollup",
+            "rows",
+            "select",
+            "set",
+            "some",
+            "struct",
+            "tablesample",
+            "then",
+            "to",
+            "treat",
+            "true",
+            "unbounded",
+            "union",
+            "unnest",
+            "using",
+            "when",
+            "where",
+            "window",
+            "with",
+            "within",
+        }
+
+        def attimezone_sql(self, expression: exp.AtTimeZone) -> str:
+            if not isinstance(expression.parent, exp.Cast):
+                return self.func(
+                    "TIMESTAMP", self.func("DATETIME", expression.this, expression.args.get("zone"))
+                )
+            return super().attimezone_sql(expression)
+
+        def trycast_sql(self, expression: exp.TryCast) -> str:
+            return self.cast_sql(expression, safe_prefix="SAFE_")
+
+        def cte_sql(self, expression: exp.CTE) -> str:
+            if expression.alias_column_names:
+                self.unsupported("Column names in CTE definition are not supported.")
+            return super().cte_sql(expression)
 
         def array_sql(self, expression: exp.Array) -> str:
             first_arg = seq_get(expression.expressions, 0)
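
The new CTE transform is easiest to see end to end. Below is a minimal sketch using sqlglot's public transpile API; the output shown in the comment is an approximation of what the BigQuery generator is expected to produce, not a verbatim capture.

    import sqlglot

    # _pushdown_cte_column_names: BigQuery rejects a column list on a CTE, so the
    # names are pushed down onto the CTE's projections as aliases instead.
    sql = "WITH cte(a, b) AS (SELECT 1, 2) SELECT a, b FROM cte"
    print(sqlglot.transpile(sql, write="bigquery")[0])
    # Roughly: WITH cte AS (SELECT 1 AS a, 2 AS b) SELECT a, b FROM cte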
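A similar sketch for _alias_ordered_group, which works around the BigQuery issue linked in the diff by replacing a grouped expression with its projection alias whenever the query also has an ORDER BY; the expected output is again an assumption about the generated SQL.

    import sqlglot

    sql = "SELECT y + 1 AS z FROM x GROUP BY y + 1 ORDER BY z"
    print(sqlglot.transpile(sql, write="bigquery")[0])
    # Roughly: SELECT y + 1 AS z FROM x GROUP BY z ORDER BY z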
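Finally, TryCast generation moved from a TRANSFORMS lambda to the trycast_sql method, which delegates to cast_sql with a SAFE_ prefix. A minimal sketch; the INT to INT64 type mapping shown in the comment is an assumption about the BigQuery generator's defaults.

    import sqlglot

    print(sqlglot.transpile("SELECT TRY_CAST(x AS INT)", write="bigquery")[0])
    # Roughly: SELECT SAFE_CAST(x AS INT64)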