summaryrefslogtreecommitdiffstats
path: root/sqlglot
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-07-10 05:36:29 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-07-10 05:36:29 +0000
commit49af28576db02470fe1d2de04e3901309b60c2e4 (patch)
tree63e63864ce7f62e9288ccb9ee438eddba08c6d49 /sqlglot
parentReleasing debian version 17.2.0-1. (diff)
downloadsqlglot-49af28576db02470fe1d2de04e3901309b60c2e4.tar.xz
sqlglot-49af28576db02470fe1d2de04e3901309b60c2e4.zip
Merging upstream version 17.3.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r--sqlglot/dialects/bigquery.py36
-rw-r--r--sqlglot/dialects/hive.py1
-rw-r--r--sqlglot/dialects/postgres.py1
-rw-r--r--sqlglot/dialects/spark.py6
-rw-r--r--sqlglot/dialects/spark2.py3
-rw-r--r--sqlglot/dialects/sqlite.py2
-rw-r--r--sqlglot/expressions.py16
-rw-r--r--sqlglot/generator.py10
-rw-r--r--sqlglot/optimizer/qualify_tables.py34
-rw-r--r--sqlglot/optimizer/scope.py8
-rw-r--r--sqlglot/parser.py30
11 files changed, 101 insertions, 46 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 82162b4..35892f7 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -174,6 +174,12 @@ def _parse_date(args: t.List) -> exp.Date | exp.DateFromParts:
return expr_type.from_arg_list(args)
+def _parse_to_hex(args: t.List) -> exp.Hex | exp.MD5:
+ # TO_HEX(MD5(..)) is common in BigQuery, so it's parsed into MD5 to simplify its transpilation
+ arg = seq_get(args, 0)
+ return exp.MD5(this=arg.this) if isinstance(arg, exp.MD5Digest) else exp.Hex(this=arg)
+
+
class BigQuery(Dialect):
UNNEST_COLUMN_ONLY = True
@@ -275,6 +281,8 @@ class BigQuery(Dialect):
"DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub),
"DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
"GENERATE_ARRAY": exp.GenerateSeries.from_arg_list,
+ "MD5": exp.MD5Digest.from_arg_list,
+ "TO_HEX": _parse_to_hex,
"PARSE_DATE": lambda args: format_time_lambda(exp.StrToDate, "bigquery")(
[seq_get(args, 1), seq_get(args, 0)]
),
@@ -379,22 +387,27 @@ class BigQuery(Dialect):
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
+ exp.Create: _create_sql,
exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
+ exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
exp.DateFromParts: rename_func("DATE"),
+ exp.DateStrToDate: datestrtodate_sql,
exp.DateSub: _date_add_sql("DATE", "SUB"),
exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
- exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
- exp.DateStrToDate: datestrtodate_sql,
exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, e.text("unit")),
- exp.JSONFormat: rename_func("TO_JSON_STRING"),
exp.GenerateSeries: rename_func("GENERATE_ARRAY"),
exp.GroupConcat: rename_func("STRING_AGG"),
+ exp.Hex: rename_func("TO_HEX"),
exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
+ exp.JSONFormat: rename_func("TO_JSON_STRING"),
exp.Max: max_or_greatest,
+ exp.MD5: lambda self, e: self.func("TO_HEX", self.func("MD5", e.this)),
+ exp.MD5Digest: rename_func("MD5"),
exp.Min: min_or_least,
+ exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.RegexpExtract: lambda self, e: self.func(
"REGEXP_EXTRACT",
e.this,
@@ -403,6 +416,7 @@ class BigQuery(Dialect):
e.args.get("occurrence"),
),
exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
+ exp.ReturnsProperty: _returnsproperty_sql,
exp.Select: transforms.preprocess(
[
transforms.explode_to_unnest,
@@ -411,6 +425,9 @@ class BigQuery(Dialect):
_alias_ordered_group,
]
),
+ exp.StabilityProperty: lambda self, e: f"DETERMINISTIC"
+ if e.name == "IMMUTABLE"
+ else "NOT DETERMINISTIC",
exp.StrToDate: lambda self, e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
exp.StrToTime: lambda self, e: self.func(
"PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone")
@@ -420,17 +437,12 @@ class BigQuery(Dialect):
exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
exp.TimeStrToTime: timestrtotime_sql,
- exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"),
+ exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression),
exp.TsOrDsAdd: _date_add_sql("DATE", "ADD"),
- exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
- exp.VariancePop: rename_func("VAR_POP"),
+ exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"),
+ exp.Unhex: rename_func("FROM_HEX"),
exp.Values: _derived_table_values_to_unnest,
- exp.ReturnsProperty: _returnsproperty_sql,
- exp.Create: _create_sql,
- exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression),
- exp.StabilityProperty: lambda self, e: f"DETERMINISTIC"
- if e.name == "IMMUTABLE"
- else "NOT DETERMINISTIC",
+ exp.VariancePop: rename_func("VAR_POP"),
}
TYPE_MAPPING = {
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 1abc0f4..5762efb 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -357,6 +357,7 @@ class Hive(Dialect):
exp.Left: left_to_substring_sql,
exp.Map: var_map_sql,
exp.Max: max_or_greatest,
+ exp.MD5Digest: lambda self, e: self.func("UNHEX", self.func("MD5", e.this)),
exp.Min: min_or_least,
exp.VarMap: var_map_sql,
exp.Create: create_with_partitions_sql,
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 6d78a07..7706456 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -263,6 +263,7 @@ class Postgres(Dialect):
"DO": TokenType.COMMAND,
"HSTORE": TokenType.HSTORE,
"JSONB": TokenType.JSONB,
+ "MONEY": TokenType.MONEY,
"REFRESH": TokenType.COMMAND,
"REINDEX": TokenType.COMMAND,
"RESET": TokenType.COMMAND,
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 7a7ee01..73f4370 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -41,6 +41,12 @@ class Spark(Spark2):
}
class Generator(Spark2.Generator):
+ TYPE_MAPPING = {
+ **Spark2.Generator.TYPE_MAPPING,
+ exp.DataType.Type.MONEY: "DECIMAL(15, 4)",
+ exp.DataType.Type.SMALLMONEY: "DECIMAL(6, 4)",
+ exp.DataType.Type.UNIQUEIDENTIFIER: "STRING",
+ }
TRANSFORMS = Spark2.Generator.TRANSFORMS.copy()
TRANSFORMS.pop(exp.DateDiff)
TRANSFORMS.pop(exp.Group)
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index afe2482..f909e8c 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -177,9 +177,6 @@ class Spark2(Hive):
TYPE_MAPPING = {
**Hive.Generator.TYPE_MAPPING,
- exp.DataType.Type.TINYINT: "BYTE",
- exp.DataType.Type.SMALLINT: "SHORT",
- exp.DataType.Type.BIGINT: "LONG",
}
PROPERTIES_LOCATION = {
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 5ded6df..90b774e 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -192,7 +192,7 @@ class SQLite(Dialect):
if len(expression.expressions) > 1:
return rename_func("MIN")(self, expression)
- return self.expressions(expression)
+ return self.sql(expression, "this")
def transaction_sql(self, expression: exp.Transaction) -> str:
this = expression.this
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index fdf02c8..242e66c 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -274,12 +274,16 @@ class Expression(metaclass=_Expression):
def set(self, arg_key: str, value: t.Any) -> None:
"""
- Sets `arg_key` to `value`.
+ Sets arg_key to value.
Args:
- arg_key (str): name of the expression arg.
+ arg_key: name of the expression arg.
value: value to set the arg to.
"""
+ if value is None:
+ self.args.pop(arg_key, None)
+ return
+
self.args[arg_key] = value
self._set_parent(arg_key, value)
@@ -2278,6 +2282,7 @@ class Table(Expression):
"pivots": False,
"hints": False,
"system_time": False,
+ "wrapped": False,
}
@property
@@ -4249,7 +4254,7 @@ class JSONArrayContains(Binary, Predicate, Func):
class Least(Func):
- arg_types = {"expressions": False}
+ arg_types = {"this": True, "expressions": False}
is_var_len_args = True
@@ -4342,6 +4347,11 @@ class MD5(Func):
_sql_names = ["MD5"]
+# Represents the variant of the MD5 function that returns a binary value
+class MD5Digest(Func):
+ _sql_names = ["MD5_DIGEST"]
+
+
class Min(AggFunc):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 1ce2aaa..4ac988f 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -1215,7 +1215,8 @@ class Generator:
system_time = expression.args.get("system_time")
system_time = f" {self.sql(expression, 'system_time')}" if system_time else ""
- return f"{table}{system_time}{alias}{hints}{pivots}{joins}{laterals}"
+ sql = f"{table}{system_time}{alias}{hints}{pivots}{joins}{laterals}"
+ return f"({sql})" if expression.args.get("wrapped") else sql
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
@@ -2289,11 +2290,14 @@ class Generator:
def function_fallback_sql(self, expression: exp.Func) -> str:
args = []
- for arg_value in expression.args.values():
+
+ for key in expression.arg_types:
+ arg_value = expression.args.get(key)
+
if isinstance(arg_value, list):
for value in arg_value:
args.append(value)
- else:
+ elif arg_value is not None:
args.append(arg_value)
return self.func(expression.sql_name(), *args)
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index 9c931d6..af8c716 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -15,8 +15,7 @@ def qualify_tables(
schema: t.Optional[Schema] = None,
) -> E:
"""
- Rewrite sqlglot AST to have fully qualified tables. Additionally, this
- replaces "join constructs" (*) by equivalent SELECT * subqueries.
+ Rewrite sqlglot AST to have fully qualified, unnested tables.
Examples:
>>> import sqlglot
@@ -24,9 +23,18 @@ def qualify_tables(
>>> qualify_tables(expression, db="db").sql()
'SELECT 1 FROM db.tbl AS tbl'
>>>
+ >>> expression = sqlglot.parse_one("SELECT * FROM (tbl)")
+ >>> qualify_tables(expression).sql()
+ 'SELECT * FROM tbl AS tbl'
+ >>>
>>> expression = sqlglot.parse_one("SELECT * FROM (tbl1 JOIN tbl2 ON id1 = id2)")
>>> qualify_tables(expression).sql()
- 'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS _q_0'
+ 'SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2'
+
+ Note:
+ This rule effectively enforces a left-to-right join order, since all joins
+ are unnested. This means that the optimizer doesn't necessarily preserve the
+ original join order, e.g. when parentheses are used to specify it explicitly.
Args:
expression: Expression to qualify
@@ -36,19 +44,11 @@ def qualify_tables(
Returns:
The qualified expression.
-
- (*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html
"""
next_alias_name = name_sequence("_q_")
for scope in traverse_scope(expression):
for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
- # Expand join construct
- if isinstance(derived_table, exp.Subquery):
- unnested = derived_table.unnest()
- if isinstance(unnested, exp.Table):
- derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False))
-
if not derived_table.args.get("alias"):
alias_ = next_alias_name()
derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_)))
@@ -66,13 +66,17 @@ def qualify_tables(
if not source.args.get("catalog"):
source.set("catalog", exp.to_identifier(catalog))
+ # Unnest joins attached in tables by appending them to the closest query
+ for join in source.args.get("joins") or []:
+ scope.expression.append("joins", join)
+
+ source.set("joins", None)
+ source.set("wrapped", None)
+
if not source.alias:
source = source.replace(
alias(
- source,
- name or source.name or next_alias_name(),
- copy=True,
- table=True,
+ source, name or source.name or next_alias_name(), copy=True, table=True
)
)
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index bc649e4..7dcfb37 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -548,9 +548,6 @@ def _traverse_scope(scope):
yield from _traverse_union(scope)
elif isinstance(scope.expression, exp.Subquery):
yield from _traverse_subqueries(scope)
- elif isinstance(scope.expression, exp.Table):
- # This case corresponds to a "join construct", i.e. (tbl1 JOIN tbl2 ON ..)
- yield from _traverse_tables(scope)
elif isinstance(scope.expression, exp.UDTF):
pass
else:
@@ -632,8 +629,9 @@ def _traverse_tables(scope):
if from_:
expressions.append(from_.this)
- for join in scope.expression.args.get("joins") or []:
- expressions.append(join.this)
+ for expression in (scope.expression, *scope.find_all(exp.Table)):
+ for join in expression.args.get("joins") or []:
+ expressions.append(join.this)
if isinstance(scope.expression, exp.Table):
expressions.append(scope.expression)
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index f7fd6ba..c7f4050 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -1969,10 +1969,31 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
- # early return so that subquery unions aren't parsed again
- # SELECT * FROM (SELECT 1) UNION ALL SELECT 1
- # Union ALL should be a property of the top select node, not the subquery
- return self._parse_subquery(this, parse_alias=parse_subquery_alias)
+ alias = None
+
+ # Ensure "wrapped" tables are not parsed as Subqueries. The exception to this is when there's
+ # an alias that can be applied to the parentheses, because that would shadow all wrapped table
+ # names, and so we want to parse it as a Subquery to represent the inner scope appropriately.
+ # Additionally, we want the node under the Subquery to be an actual query, so we will replace
+ # the table reference with a star query that selects from it.
+ if isinstance(this, exp.Table):
+ alias = self._parse_table_alias()
+ if not alias:
+ this.set("wrapped", True)
+ return this
+
+ this.set("wrapped", None)
+ joins = this.args.pop("joins", None)
+ this = this.replace(exp.select("*").from_(this.copy(), copy=False))
+ this.set("joins", joins)
+
+ subquery = self._parse_subquery(this, parse_alias=parse_subquery_alias and not alias)
+ if subquery and alias:
+ subquery.set("alias", alias)
+
+ # We return early here so that the UNION isn't attached to the subquery by the
+ # following call to _parse_set_operations, but instead becomes the parent node
+ return subquery
elif self._match(TokenType.VALUES):
this = self.expression(
exp.Values,
@@ -2292,6 +2313,7 @@ class Parser(metaclass=_Parser):
else:
joins = None
self._retreat(index)
+
kwargs["this"].set("joins", joins)
return self.expression(exp.Join, **kwargs)