summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/bigquery.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/bigquery.py')
-rw-r--r--sqlglot/dialects/bigquery.py145
1 files changed, 56 insertions, 89 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index c0191b2..f867617 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -12,13 +12,14 @@ from sqlglot.dialects.dialect import (
binary_from_function,
date_add_interval_sql,
datestrtodate_sql,
- format_time_lambda,
+ build_formatted_time,
+ filter_array_using_unnest,
if_sql,
inline_array_sql,
max_or_greatest,
min_or_least,
no_ilike_sql,
- parse_date_delta_with_interval,
+ build_date_delta_with_interval,
regexp_replace_sql,
rename_func,
timestrtotime_sql,
@@ -37,56 +38,33 @@ def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Va
if not expression.find_ancestor(exp.From, exp.Join):
return self.values_sql(expression)
+ structs = []
alias = expression.args.get("alias")
+ for tup in expression.find_all(exp.Tuple):
+ field_aliases = alias.columns if alias else (f"_c{i}" for i in range(len(tup.expressions)))
+ expressions = [exp.alias_(fld, name) for fld, name in zip(tup.expressions, field_aliases)]
+ structs.append(exp.Struct(expressions=expressions))
- return self.unnest_sql(
- exp.Unnest(
- expressions=[
- exp.array(
- *(
- exp.Struct(
- expressions=[
- exp.alias_(value, column_name)
- for value, column_name in zip(
- t.expressions,
- (
- alias.columns
- if alias and alias.columns
- else (f"_c{i}" for i in range(len(t.expressions)))
- ),
- )
- ]
- )
- for t in expression.find_all(exp.Tuple)
- ),
- copy=False,
- )
- ]
- )
- )
+ return self.unnest_sql(exp.Unnest(expressions=[exp.array(*structs, copy=False)]))
def _returnsproperty_sql(self: BigQuery.Generator, expression: exp.ReturnsProperty) -> str:
this = expression.this
if isinstance(this, exp.Schema):
- this = f"{this.this} <{self.expressions(this)}>"
+ this = f"{self.sql(this, 'this')} <{self.expressions(this)}>"
else:
this = self.sql(this)
return f"RETURNS {this}"
def _create_sql(self: BigQuery.Generator, expression: exp.Create) -> str:
- kind = expression.args["kind"]
returns = expression.find(exp.ReturnsProperty)
-
- if kind.upper() == "FUNCTION" and returns and returns.args.get("is_table"):
+ if expression.kind == "FUNCTION" and returns and returns.args.get("is_table"):
expression.set("kind", "TABLE FUNCTION")
if isinstance(expression.expression, (exp.Subquery, exp.Literal)):
expression.set("expression", expression.expression.this)
- return self.create_sql(expression)
-
return self.create_sql(expression)
@@ -132,11 +110,10 @@ def _alias_ordered_group(expression: exp.Expression) -> exp.Expression:
if isinstance(select, exp.Alias)
}
- for e in group.expressions:
- alias = aliases.get(e)
-
+ for grouped in group.expressions:
+ alias = aliases.get(grouped)
if alias:
- e.replace(exp.column(alias))
+ grouped.replace(exp.column(alias))
return expression
@@ -168,24 +145,24 @@ def _pushdown_cte_column_names(expression: exp.Expression) -> exp.Expression:
return expression
-def _parse_parse_timestamp(args: t.List) -> exp.StrToTime:
- this = format_time_lambda(exp.StrToTime, "bigquery")([seq_get(args, 1), seq_get(args, 0)])
+def _build_parse_timestamp(args: t.List) -> exp.StrToTime:
+ this = build_formatted_time(exp.StrToTime, "bigquery")([seq_get(args, 1), seq_get(args, 0)])
this.set("zone", seq_get(args, 2))
return this
-def _parse_timestamp(args: t.List) -> exp.Timestamp:
+def _build_timestamp(args: t.List) -> exp.Timestamp:
timestamp = exp.Timestamp.from_arg_list(args)
timestamp.set("with_tz", True)
return timestamp
-def _parse_date(args: t.List) -> exp.Date | exp.DateFromParts:
+def _build_date(args: t.List) -> exp.Date | exp.DateFromParts:
expr_type = exp.DateFromParts if len(args) == 3 else exp.Date
return expr_type.from_arg_list(args)
-def _parse_to_hex(args: t.List) -> exp.Hex | exp.MD5:
+def _build_to_hex(args: t.List) -> exp.Hex | exp.MD5:
# TO_HEX(MD5(..)) is common in BigQuery, so it's parsed into MD5 to simplify its transpilation
arg = seq_get(args, 0)
return exp.MD5(this=arg.this) if isinstance(arg, exp.MD5Digest) else exp.Hex(this=arg)
@@ -214,18 +191,20 @@ def _ts_or_ds_diff_sql(self: BigQuery.Generator, expression: exp.TsOrDsDiff) ->
def _unix_to_time_sql(self: BigQuery.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
- timestamp = self.sql(expression, "this")
+ timestamp = expression.this
+
if scale in (None, exp.UnixToTime.SECONDS):
- return f"TIMESTAMP_SECONDS({timestamp})"
+ return self.func("TIMESTAMP_SECONDS", timestamp)
if scale == exp.UnixToTime.MILLIS:
- return f"TIMESTAMP_MILLIS({timestamp})"
+ return self.func("TIMESTAMP_MILLIS", timestamp)
if scale == exp.UnixToTime.MICROS:
- return f"TIMESTAMP_MICROS({timestamp})"
+ return self.func("TIMESTAMP_MICROS", timestamp)
- return f"TIMESTAMP_SECONDS(CAST({timestamp} / POW(10, {scale}) AS INT64))"
+ unix_seconds = exp.cast(exp.Div(this=timestamp, expression=exp.func("POW", 10, scale)), "int64")
+ return self.func("TIMESTAMP_SECONDS", unix_seconds)
-def _parse_time(args: t.List) -> exp.Func:
+def _build_time(args: t.List) -> exp.Func:
if len(args) == 1:
return exp.TsOrDsToTime(this=args[0])
if len(args) == 3:
@@ -323,6 +302,7 @@ class BigQuery(Dialect):
"BYTES": TokenType.BINARY,
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
"DECLARE": TokenType.COMMAND,
+ "EXCEPTION": TokenType.COMMAND,
"FLOAT64": TokenType.DOUBLE,
"FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
"MODEL": TokenType.MODEL,
@@ -340,15 +320,15 @@ class BigQuery(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
- "DATE": _parse_date,
- "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
- "DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
+ "DATE": _build_date,
+ "DATE_ADD": build_date_delta_with_interval(exp.DateAdd),
+ "DATE_SUB": build_date_delta_with_interval(exp.DateSub),
"DATE_TRUNC": lambda args: exp.DateTrunc(
unit=exp.Literal.string(str(seq_get(args, 1))),
this=seq_get(args, 0),
),
- "DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd),
- "DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub),
+ "DATETIME_ADD": build_date_delta_with_interval(exp.DatetimeAdd),
+ "DATETIME_SUB": build_date_delta_with_interval(exp.DatetimeSub),
"DIV": binary_from_function(exp.IntDiv),
"FORMAT_DATE": lambda args: exp.TimeToStr(
this=exp.TsOrDsToDate(this=seq_get(args, 1)), format=seq_get(args, 0)
@@ -358,11 +338,11 @@ class BigQuery(Dialect):
this=seq_get(args, 0), expression=seq_get(args, 1) or exp.Literal.string("$")
),
"MD5": exp.MD5Digest.from_arg_list,
- "TO_HEX": _parse_to_hex,
- "PARSE_DATE": lambda args: format_time_lambda(exp.StrToDate, "bigquery")(
+ "TO_HEX": _build_to_hex,
+ "PARSE_DATE": lambda args: build_formatted_time(exp.StrToDate, "bigquery")(
[seq_get(args, 1), seq_get(args, 0)]
),
- "PARSE_TIMESTAMP": _parse_parse_timestamp,
+ "PARSE_TIMESTAMP": _build_parse_timestamp,
"REGEXP_CONTAINS": exp.RegexpLike.from_arg_list,
"REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
this=seq_get(args, 0),
@@ -378,12 +358,12 @@ class BigQuery(Dialect):
this=seq_get(args, 0),
expression=seq_get(args, 1) or exp.Literal.string(","),
),
- "TIME": _parse_time,
- "TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd),
- "TIME_SUB": parse_date_delta_with_interval(exp.TimeSub),
- "TIMESTAMP": _parse_timestamp,
- "TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd),
- "TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub),
+ "TIME": _build_time,
+ "TIME_ADD": build_date_delta_with_interval(exp.TimeAdd),
+ "TIME_SUB": build_date_delta_with_interval(exp.TimeSub),
+ "TIMESTAMP": _build_timestamp,
+ "TIMESTAMP_ADD": build_date_delta_with_interval(exp.TimestampAdd),
+ "TIMESTAMP_SUB": build_date_delta_with_interval(exp.TimestampSub),
"TIMESTAMP_MICROS": lambda args: exp.UnixToTime(
this=seq_get(args, 0), scale=exp.UnixToTime.MICROS
),
@@ -424,7 +404,7 @@ class BigQuery(Dialect):
}
RANGE_PARSERS = parser.Parser.RANGE_PARSERS.copy()
- RANGE_PARSERS.pop(TokenType.OVERLAPS, None)
+ RANGE_PARSERS.pop(TokenType.OVERLAPS)
NULL_TOKENS = {TokenType.NULL, TokenType.UNKNOWN}
@@ -551,6 +531,7 @@ class BigQuery(Dialect):
NULL_ORDERING_SUPPORTED = False
IGNORE_NULLS_IN_FUNC = True
JSON_PATH_SINGLE_QUOTE_ESCAPE = True
+ CAN_IMPLEMENT_ARRAY_ANY = True
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
@@ -558,6 +539,7 @@ class BigQuery(Dialect):
exp.ArgMax: arg_max_or_min_no_count("MAX_BY"),
exp.ArgMin: arg_max_or_min_no_count("MIN_BY"),
exp.ArrayContains: _array_contains_sql,
+ exp.ArrayFilter: filter_array_using_unnest,
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
exp.CollateProperty: lambda self, e: (
@@ -565,12 +547,14 @@ class BigQuery(Dialect):
if e.args.get("default")
else f"COLLATE {self.sql(e, 'this')}"
),
+ exp.Commit: lambda *_: "COMMIT TRANSACTION",
exp.CountIf: rename_func("COUNTIF"),
exp.Create: _create_sql,
exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
exp.DateAdd: date_add_interval_sql("DATE", "ADD"),
- exp.DateDiff: lambda self,
- e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
+ exp.DateDiff: lambda self, e: self.func(
+ "DATE_DIFF", e.this, e.expression, e.unit or "DAY"
+ ),
exp.DateFromParts: rename_func("DATE"),
exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: date_add_interval_sql("DATE", "SUB"),
@@ -602,6 +586,7 @@ class BigQuery(Dialect):
exp.RegexpReplace: regexp_replace_sql,
exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
exp.ReturnsProperty: _returnsproperty_sql,
+ exp.Rollback: lambda *_: "ROLLBACK TRANSACTION",
exp.Select: transforms.preprocess(
[
transforms.explode_to_unnest(),
@@ -617,8 +602,7 @@ class BigQuery(Dialect):
exp.StabilityProperty: lambda self, e: (
"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC"
),
- exp.StrToDate: lambda self,
- e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
+ exp.StrToDate: lambda self, e: self.func("PARSE_DATE", self.format_time(e), e.this),
exp.StrToTime: lambda self, e: self.func(
"PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone")
),
@@ -629,6 +613,7 @@ class BigQuery(Dialect):
exp.TimestampDiff: rename_func("TIMESTAMP_DIFF"),
exp.TimestampSub: date_add_interval_sql("TIMESTAMP", "SUB"),
exp.TimeStrToTime: timestrtotime_sql,
+ exp.Transaction: lambda *_: "BEGIN TRANSACTION",
exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression),
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsDiff: _ts_or_ds_diff_sql,
@@ -778,12 +763,8 @@ class BigQuery(Dialect):
}
def timetostr_sql(self, expression: exp.TimeToStr) -> str:
- if isinstance(expression.this, exp.TsOrDsToDate):
- this: exp.Expression = expression.this
- else:
- this = expression
-
- return f"FORMAT_DATE({self.format_time(expression)}, {self.sql(this, 'this')})"
+ this = expression.this if isinstance(expression.this, exp.TsOrDsToDate) else expression
+ return self.func("FORMAT_DATE", self.format_time(expression), this.this)
def struct_sql(self, expression: exp.Struct) -> str:
args = []
@@ -820,11 +801,6 @@ class BigQuery(Dialect):
def trycast_sql(self, expression: exp.TryCast) -> str:
return self.cast_sql(expression, safe_prefix="SAFE_")
- def cte_sql(self, expression: exp.CTE) -> str:
- if expression.alias_column_names:
- self.unsupported("Column names in CTE definition are not supported.")
- return super().cte_sql(expression)
-
def array_sql(self, expression: exp.Array) -> str:
first_arg = seq_get(expression.expressions, 0)
if isinstance(first_arg, exp.Subqueryable):
@@ -862,25 +838,16 @@ class BigQuery(Dialect):
return f"{this}[{expressions_sql}]"
- def transaction_sql(self, *_) -> str:
- return "BEGIN TRANSACTION"
-
- def commit_sql(self, *_) -> str:
- return "COMMIT TRANSACTION"
-
- def rollback_sql(self, *_) -> str:
- return "ROLLBACK TRANSACTION"
-
def in_unnest_op(self, expression: exp.Unnest) -> str:
return self.sql(expression)
def except_op(self, expression: exp.Except) -> str:
- if not expression.args.get("distinct", False):
+ if not expression.args.get("distinct"):
self.unsupported("EXCEPT without DISTINCT is not supported in BigQuery")
return f"EXCEPT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
def intersect_op(self, expression: exp.Intersect) -> str:
- if not expression.args.get("distinct", False):
+ if not expression.args.get("distinct"):
self.unsupported("INTERSECT without DISTINCT is not supported in BigQuery")
return f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"