diff options
Diffstat (limited to 'sqlglot/dialects/bigquery.py')
-rw-r--r-- | sqlglot/dialects/bigquery.py | 86 |
1 files changed, 74 insertions, 12 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 71977dd..d763ed0 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import ( datestrtodate_sql, format_time_lambda, inline_array_sql, + json_keyvalue_comma_sql, max_or_greatest, min_or_least, no_ilike_sql, @@ -29,8 +30,8 @@ logger = logging.getLogger("sqlglot") def _date_add_sql( data_type: str, kind: str -) -> t.Callable[[generator.Generator, exp.Expression], str]: - def func(self, expression): +) -> t.Callable[[BigQuery.Generator, exp.Expression], str]: + def func(self: BigQuery.Generator, expression: exp.Expression) -> str: this = self.sql(expression, "this") unit = expression.args.get("unit") unit = exp.var(unit.name.upper() if unit else "DAY") @@ -40,7 +41,7 @@ def _date_add_sql( return func -def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.Values) -> str: +def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Values) -> str: if not expression.find_ancestor(exp.From, exp.Join): return self.values_sql(expression) @@ -64,7 +65,7 @@ def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.V return self.unnest_sql(exp.Unnest(expressions=[exp.Array(expressions=structs)])) -def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsProperty) -> str: +def _returnsproperty_sql(self: BigQuery.Generator, expression: exp.ReturnsProperty) -> str: this = expression.this if isinstance(this, exp.Schema): this = f"{this.this} <{self.expressions(this)}>" @@ -73,7 +74,7 @@ def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsPrope return f"RETURNS {this}" -def _create_sql(self: generator.Generator, expression: exp.Create) -> str: +def _create_sql(self: BigQuery.Generator, expression: exp.Create) -> str: kind = expression.args["kind"] returns = expression.find(exp.ReturnsProperty) @@ -94,14 +95,20 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression: These are added by the optimizer's qualify_column step. """ - from sqlglot.optimizer.scope import Scope + from sqlglot.optimizer.scope import find_all_in_scope if isinstance(expression, exp.Select): - for unnest in expression.find_all(exp.Unnest): - if isinstance(unnest.parent, (exp.From, exp.Join)) and unnest.alias: - for column in Scope(expression).find_all(exp.Column): - if column.table == unnest.alias: - column.set("table", None) + unnest_aliases = { + unnest.alias + for unnest in find_all_in_scope(expression, exp.Unnest) + if isinstance(unnest.parent, (exp.From, exp.Join)) + } + if unnest_aliases: + for column in expression.find_all(exp.Column): + if column.table in unnest_aliases: + column.set("table", None) + elif column.db in unnest_aliases: + column.set("db", None) return expression @@ -261,6 +268,7 @@ class BigQuery(Dialect): "TIMESTAMP": TokenType.TIMESTAMPTZ, "NOT DETERMINISTIC": TokenType.VOLATILE, "UNKNOWN": TokenType.NULL, + "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT, } KEYWORDS.pop("DIV") @@ -270,6 +278,8 @@ class BigQuery(Dialect): LOG_BASE_FIRST = False LOG_DEFAULTS_TO_LN = True + SUPPORTS_USER_DEFINED_TYPES = False + FUNCTIONS = { **parser.Parser.FUNCTIONS, "DATE": _parse_date, @@ -299,6 +309,8 @@ class BigQuery(Dialect): if re.compile(str(seq_get(args, 1))).groups == 1 else None, ), + "SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)), + "SHA512": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(512)), "SPLIT": lambda args: exp.Split( # https://cloud.google.com/bigquery/docs/reference/standard-sql/string_functions#split this=seq_get(args, 0), @@ -346,7 +358,7 @@ class BigQuery(Dialect): } def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]: - this = super()._parse_table_part(schema=schema) + this = super()._parse_table_part(schema=schema) or self._parse_number() # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#table_names if isinstance(this, exp.Identifier): @@ -356,6 +368,17 @@ class BigQuery(Dialect): table_name += f"-{self._prev.text}" this = exp.Identifier(this=table_name, quoted=this.args.get("quoted")) + elif isinstance(this, exp.Literal): + table_name = this.name + + if ( + self._curr + and self._prev.end == self._curr.start - 1 + and self._parse_var(any_token=True) + ): + table_name += self._prev.text + + this = exp.Identifier(this=table_name, quoted=True) return this @@ -374,6 +397,27 @@ class BigQuery(Dialect): return table + def _parse_json_object(self) -> exp.JSONObject: + json_object = super()._parse_json_object() + array_kv_pair = seq_get(json_object.expressions, 0) + + # Converts BQ's "signature 2" of JSON_OBJECT into SQLGlot's canonical representation + # https://cloud.google.com/bigquery/docs/reference/standard-sql/json_functions#json_object_signature2 + if ( + array_kv_pair + and isinstance(array_kv_pair.this, exp.Array) + and isinstance(array_kv_pair.expression, exp.Array) + ): + keys = array_kv_pair.this.expressions + values = array_kv_pair.expression.expressions + + json_object.set( + "expressions", + [exp.JSONKeyValue(this=k, expression=v) for k, v in zip(keys, values)], + ) + + return json_object + class Generator(generator.Generator): EXPLICIT_UNION = True INTERVAL_ALLOWS_PLURAL_FORM = False @@ -383,6 +427,7 @@ class BigQuery(Dialect): LIMIT_FETCH = "LIMIT" RENAME_TABLE_WITH_DB = False ESCAPE_LINE_BREAK = True + NVL2_SUPPORTED = False TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -405,6 +450,7 @@ class BigQuery(Dialect): exp.ILike: no_ilike_sql, exp.IntDiv: rename_func("DIV"), exp.JSONFormat: rename_func("TO_JSON_STRING"), + exp.JSONKeyValue: json_keyvalue_comma_sql, exp.Max: max_or_greatest, exp.MD5: lambda self, e: self.func("TO_HEX", self.func("MD5", e.this)), exp.MD5Digest: rename_func("MD5"), @@ -428,6 +474,9 @@ class BigQuery(Dialect): _alias_ordered_group, ] ), + exp.SHA2: lambda self, e: self.func( + f"SHA256" if e.text("length") == "256" else "SHA512", e.this + ), exp.StabilityProperty: lambda self, e: f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC", @@ -591,6 +640,13 @@ class BigQuery(Dialect): return super().attimezone_sql(expression) + def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: + # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#json_literals + if expression.is_type("json"): + return f"JSON {self.sql(expression, 'this')}" + + return super().cast_sql(expression, safe_prefix=safe_prefix) + def trycast_sql(self, expression: exp.TryCast) -> str: return self.cast_sql(expression, safe_prefix="SAFE_") @@ -630,3 +686,9 @@ class BigQuery(Dialect): def with_properties(self, properties: exp.Properties) -> str: return self.properties(properties, prefix=self.seg("OPTIONS")) + + def version_sql(self, expression: exp.Version) -> str: + if expression.name == "TIMESTAMP": + expression = expression.copy() + expression.set("this", "SYSTEM_TIME") + return super().version_sql(expression) |