diff options
Diffstat (limited to 'sqlglot/dialects/bigquery.py')
-rw-r--r-- | sqlglot/dialects/bigquery.py | 54 |
1 files changed, 42 insertions, 12 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 90ae229..6a19b46 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -2,6 +2,8 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, @@ -14,8 +16,10 @@ from sqlglot.dialects.dialect import ( from sqlglot.helper import seq_get from sqlglot.tokens import TokenType +E = t.TypeVar("E", bound=exp.Expression) + -def _date_add(expression_class): +def _date_add(expression_class: t.Type[E]) -> t.Callable[[t.Sequence], E]: def func(args): interval = seq_get(args, 1) return expression_class( @@ -27,26 +31,26 @@ def _date_add(expression_class): return func -def _date_trunc(args): +def _date_trunc(args: t.Sequence) -> exp.Expression: unit = seq_get(args, 1) if isinstance(unit, exp.Column): unit = exp.Var(this=unit.name) return exp.DateTrunc(this=seq_get(args, 0), expression=unit) -def _date_add_sql(data_type, kind): +def _date_add_sql( + data_type: str, kind: str +) -> t.Callable[[generator.Generator, exp.Expression], str]: def func(self, expression): this = self.sql(expression, "this") - unit = self.sql(expression, "unit") or "'day'" - expression = self.sql(expression, "expression") - return f"{data_type}_{kind}({this}, INTERVAL {expression} {unit})" + return f"{data_type}_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=expression.args.get('unit') or exp.Literal.string('day')))})" return func -def _derived_table_values_to_unnest(self, expression): +def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.Values) -> str: if not isinstance(expression.unnest().parent, exp.From): - expression = transforms.remove_precision_parameterized_types(expression) + expression = t.cast(exp.Values, transforms.remove_precision_parameterized_types(expression)) return self.values_sql(expression) rows = [tuple_exp.expressions for tuple_exp in expression.find_all(exp.Tuple)] structs = [] @@ -60,7 +64,7 @@ def _derived_table_values_to_unnest(self, expression): return self.unnest_sql(unnest_exp) -def _returnsproperty_sql(self, expression): +def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsProperty) -> str: this = expression.this if isinstance(this, exp.Schema): this = f"{this.this} <{self.expressions(this)}>" @@ -69,8 +73,8 @@ def _returnsproperty_sql(self, expression): return f"RETURNS {this}" -def _create_sql(self, expression): - kind = expression.args.get("kind") +def _create_sql(self: generator.Generator, expression: exp.Create) -> str: + kind = expression.args["kind"] returns = expression.find(exp.ReturnsProperty) if kind.upper() == "FUNCTION" and returns and returns.args.get("is_table"): expression = expression.copy() @@ -89,6 +93,29 @@ def _create_sql(self, expression): return self.create_sql(expression) +def _unqualify_unnest(expression: exp.Expression) -> exp.Expression: + """Remove references to unnest table aliases since bigquery doesn't allow them. + + These are added by the optimizer's qualify_column step. + """ + if isinstance(expression, exp.Select): + unnests = { + unnest.alias + for unnest in expression.args.get("from", exp.From(expressions=[])).expressions + if isinstance(unnest, exp.Unnest) and unnest.alias + } + + if unnests: + expression = expression.copy() + + for select in expression.expressions: + for column in select.find_all(exp.Column): + if column.table in unnests: + column.set("table", None) + + return expression + + class BigQuery(Dialect): unnest_column_only = True time_mapping = { @@ -110,7 +137,7 @@ class BigQuery(Dialect): ] COMMENTS = ["--", "#", ("/*", "*/")] IDENTIFIERS = ["`"] - ESCAPES = ["\\"] + STRING_ESCAPES = ["\\"] HEX_STRINGS = [("0x", ""), ("0X", "")] KEYWORDS = { @@ -190,6 +217,9 @@ class BigQuery(Dialect): exp.GroupConcat: rename_func("STRING_AGG"), exp.ILike: no_ilike_sql, exp.IntDiv: rename_func("DIV"), + exp.Select: transforms.preprocess( + [_unqualify_unnest], transforms.delegate("select_sql") + ), exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})", exp.TimeAdd: _date_add_sql("TIME", "ADD"), exp.TimeSub: _date_add_sql("TIME", "SUB"), |