diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-03-19 10:22:09 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-03-19 10:22:09 +0000 |
commit | 66af5c6fc22f6f11e9ea807b274e011a6f64efb7 (patch) | |
tree | 08ceed3b311b7b343935c1e55941b9d15e6f56d8 /sqlglot/optimizer | |
parent | Releasing debian version 11.3.6-1. (diff) | |
download | sqlglot-66af5c6fc22f6f11e9ea807b274e011a6f64efb7.tar.xz sqlglot-66af5c6fc22f6f11e9ea807b274e011a6f64efb7.zip |
Merging upstream version 11.4.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r-- | sqlglot/optimizer/canonicalize.py | 28 | ||||
-rw-r--r-- | sqlglot/optimizer/pushdown_projections.py | 17 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py | 22 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_tables.py | 8 | ||||
-rw-r--r-- | sqlglot/optimizer/scope.py | 5 |
5 files changed, 62 insertions, 18 deletions
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py index fc37a54..c5c780d 100644 --- a/sqlglot/optimizer/canonicalize.py +++ b/sqlglot/optimizer/canonicalize.py @@ -1,9 +1,12 @@ +from __future__ import annotations + import itertools from sqlglot import exp +from sqlglot.helper import should_identify -def canonicalize(expression: exp.Expression) -> exp.Expression: +def canonicalize(expression: exp.Expression, identify: str = "safe") -> exp.Expression: """Converts a sql expression into a standard form. This method relies on annotate_types because many of the @@ -11,15 +14,18 @@ def canonicalize(expression: exp.Expression) -> exp.Expression: Args: expression: The expression to canonicalize. + identify: Whether or not to force identify identifier. """ - exp.replace_children(expression, canonicalize) + exp.replace_children(expression, canonicalize, identify=identify) expression = add_text_to_concat(expression) expression = coerce_type(expression) expression = remove_redundant_casts(expression) + expression = ensure_bool_predicates(expression) if isinstance(expression, exp.Identifier): - expression.set("quoted", True) + if should_identify(expression.this, identify): + expression.set("quoted", True) return expression @@ -52,6 +58,17 @@ def remove_redundant_casts(expression: exp.Expression) -> exp.Expression: return expression +def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression: + if isinstance(expression, exp.Connector): + _replace_int_predicate(expression.left) + _replace_int_predicate(expression.right) + + elif isinstance(expression, (exp.Where, exp.Having)): + _replace_int_predicate(expression.this) + + return expression + + def _coerce_date(a: exp.Expression, b: exp.Expression) -> None: for a, b in itertools.permutations([a, b]): if ( @@ -68,3 +85,8 @@ def _replace_cast(node: exp.Expression, to: str) -> None: cast = exp.Cast(this=node.copy(), to=data_type) cast.type = data_type node.replace(cast) + + +def _replace_int_predicate(expression: exp.Expression) -> None: + if expression.type and expression.type.this in exp.DataType.INTEGER_TYPES: + expression.replace(exp.NEQ(this=expression.copy(), expression=exp.Literal.number(0))) diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index 07a1b70..2e51117 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -1,7 +1,6 @@ from collections import defaultdict from sqlglot import alias, exp -from sqlglot.helper import flatten from sqlglot.optimizer.qualify_columns import Resolver from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.schema import ensure_schema @@ -86,14 +85,15 @@ def _remove_unused_selections(scope, parent_selections, schema): else: order_refs = set() - new_selections = defaultdict(list) + new_selections = [] removed = False star = False + for selection in scope.selects: name = selection.alias_or_name if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs: - new_selections[name].append(selection) + new_selections.append(selection) else: if selection.is_star: star = True @@ -101,18 +101,17 @@ def _remove_unused_selections(scope, parent_selections, schema): if star: resolver = Resolver(scope, schema) + names = {s.alias_or_name for s in new_selections} for name in sorted(parent_selections): - if name not in new_selections: - new_selections[name].append( - alias(exp.column(name, table=resolver.get_table(name)), name) - ) + if name not in names: + new_selections.append(alias(exp.column(name, table=resolver.get_table(name)), name)) # If there are no remaining selections, just select a single constant if not new_selections: - new_selections[""].append(DEFAULT_SELECTION()) + new_selections.append(DEFAULT_SELECTION()) - scope.expression.select(*flatten(new_selections.values()), append=False, copy=False) + scope.expression.select(*new_selections, append=False, copy=False) if removed: scope.clear_cache() diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index e793e31..66b3170 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -37,6 +37,7 @@ def qualify_columns(expression, schema): _qualify_outputs(scope) _expand_group_by(scope, resolver) _expand_order_by(scope) + return expression @@ -213,6 +214,21 @@ def _qualify_columns(scope, resolver): # column_table can be a '' because bigquery unnest has no table alias if column_table: column.set("table", column_table) + elif column_table not in scope.sources: + # structs are used like tables (e.g. "struct"."field"), so they need to be qualified + # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...)) + + root, *parts = column.parts + + if root.name in scope.sources: + # struct is already qualified, but we still need to change the AST representation + column_table = root + root, *parts = parts + else: + column_table = resolver.get_table(root.name) + + if column_table: + column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts])) columns_missing_from_scope = [] # Determine whether each reference in the order by clause is to a column or an alias. @@ -373,10 +389,14 @@ class Resolver: if isinstance(node, exp.Subqueryable): while node and node.alias != table_name: node = node.parent + node_alias = node.args.get("alias") if node_alias: return node_alias.this - return exp.to_identifier(table_name) + + return exp.to_identifier( + table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None + ) @property def all_columns(self): diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index 6e50182..93e1179 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -34,11 +34,9 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_))) scope.rename_source(None, alias_) - for source in scope.sources.values(): + for name, source in scope.sources.items(): if isinstance(source, exp.Table): - identifier = isinstance(source.this, exp.Identifier) - - if identifier: + if isinstance(source.this, exp.Identifier): if not source.args.get("db"): source.set("db", exp.to_identifier(db)) if not source.args.get("catalog"): @@ -48,7 +46,7 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): source = source.replace( alias( source.copy(), - source.this if identifier else next_name(), + name if name else next_name(), table=True, ) ) diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 335ff3e..9c0768c 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -4,6 +4,7 @@ from enum import Enum, auto from sqlglot import exp from sqlglot.errors import OptimizeError +from sqlglot.helper import find_new_name class ScopeType(Enum): @@ -293,6 +294,8 @@ class Scope: result = {} for name, node in referenced_names: + if name in result: + raise OptimizeError(f"Alias already used: {name}") if name in self.sources: result[name] = (node, self.sources[name]) @@ -594,6 +597,8 @@ def _traverse_tables(scope): if table_name in scope.sources: # This is a reference to a parent source (e.g. a CTE), not an actual table. sources[source_name] = scope.sources[table_name] + elif source_name in sources: + sources[find_new_name(sources, table_name)] = expression else: sources[source_name] = expression continue |