diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-03-03 14:11:07 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-03-03 14:11:07 +0000 |
commit | 42a1548cecf48d18233f56e3385cf9c89abcb9c2 (patch) | |
tree | 5e0fff4ecbd1fd7dd1022a7580139038df2a824c /sqlglot/optimizer | |
parent | Releasing debian version 21.1.2-1. (diff) | |
download | sqlglot-42a1548cecf48d18233f56e3385cf9c89abcb9c2.tar.xz sqlglot-42a1548cecf48d18233f56e3385cf9c89abcb9c2.zip |
Merging upstream version 22.2.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r-- | sqlglot/optimizer/annotate_types.py | 41 | ||||
-rw-r--r-- | sqlglot/optimizer/normalize_identifiers.py | 6 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py | 21 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_tables.py | 4 | ||||
-rw-r--r-- | sqlglot/optimizer/scope.py | 12 | ||||
-rw-r--r-- | sqlglot/optimizer/simplify.py | 16 | ||||
-rw-r--r-- | sqlglot/optimizer/unnest_subqueries.py | 14 |
7 files changed, 90 insertions, 24 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index ce274bb..81b1ee6 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -191,6 +191,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.DateToDi, exp.Floor, exp.Levenshtein, + exp.Sign, exp.StrPosition, exp.TsOrDiToDi, }, @@ -262,6 +263,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), exp.Div: lambda self, e: self._annotate_div(e), + exp.Dot: lambda self, e: self._annotate_dot(e), exp.Explode: lambda self, e: self._annotate_explode(e), exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), @@ -273,15 +275,17 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), + exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"), exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), + exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True), exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), exp.Timestamp: lambda self, e: self._annotate_with_type( e, exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP, ), exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), + exp.Unnest: lambda self, e: self._annotate_unnest(e), exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP), - exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True), } NESTED_TYPES = { @@ -380,8 +384,11 @@ class TypeAnnotator(metaclass=_TypeAnnotator): source = scope.sources.get(col.table) if isinstance(source, exp.Table): self._set_type(col, self.schema.get_column_type(source, col)) - elif source and col.table in selects and col.name in selects[col.table]: - self._set_type(col, selects[col.table][col.name].type) + elif source: + if col.table in selects and col.name in selects[col.table]: + self._set_type(col, selects[col.table][col.name].type) + elif isinstance(source.expression, exp.Unnest): + self._set_type(col, source.expression.type) # Then (possibly) annotate the remaining expressions in the scope self._maybe_annotate(scope.expression) @@ -514,7 +521,14 @@ class TypeAnnotator(metaclass=_TypeAnnotator): last_datatype = None for expr in expressions: - last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type) + expr_type = expr.type + + # Stop at the first nested data type found - we don't want to _maybe_coerce nested types + if expr_type.args.get("nested"): + last_datatype = expr_type + break + + last_datatype = self._maybe_coerce(last_datatype or expr_type, expr_type) self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN) @@ -594,7 +608,26 @@ class TypeAnnotator(metaclass=_TypeAnnotator): return expression + def _annotate_dot(self, expression: exp.Dot) -> exp.Dot: + self._annotate_args(expression) + self._set_type(expression, None) + this_type = expression.this.type + + if this_type and this_type.is_type(exp.DataType.Type.STRUCT): + for e in this_type.expressions: + if e.name == expression.expression.name: + self._set_type(expression, e.kind) + break + + return expression + def _annotate_explode(self, expression: exp.Explode) -> exp.Explode: self._annotate_args(expression) self._set_type(expression, seq_get(expression.this.type.expressions, 0)) return expression + + def _annotate_unnest(self, expression: exp.Unnest) -> exp.Unnest: + self._annotate_args(expression) + child = seq_get(expression.expressions, 0) + self._set_type(expression, child and seq_get(child.type.expressions, 0)) + return expression diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py index d22a998..f2a0990 100644 --- a/sqlglot/optimizer/normalize_identifiers.py +++ b/sqlglot/optimizer/normalize_identifiers.py @@ -10,13 +10,11 @@ if t.TYPE_CHECKING: @t.overload -def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: - ... +def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: ... @t.overload -def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: - ... +def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: ... def normalize_identifiers(expression, dialect=None): diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index ef589c9..233ffc9 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -120,6 +120,8 @@ def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2) """ for derived_table in derived_tables: + if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive: + continue table_alias = derived_table.args.get("alias") if table_alias: table_alias.args.pop("columns", None) @@ -214,7 +216,13 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: table = resolver.get_table(column.name) if resolve_table and not column.table else None alias_expr, i = alias_to_expression.get(column.name, (None, 1)) double_agg = ( - (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc)) + ( + alias_expr.find(exp.AggFunc) + and ( + column.find_ancestor(exp.AggFunc) + and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window) + ) + ) if alias_expr else False ) @@ -404,7 +412,7 @@ def _expand_stars( tables = list(scope.selected_sources) _add_except_columns(expression, tables, except_columns) _add_replace_columns(expression, tables, replace_columns) - elif expression.is_star: + elif expression.is_star and not isinstance(expression, exp.Dot): tables = [expression.table] _add_except_columns(expression.this, tables, except_columns) _add_replace_columns(expression.this, tables, replace_columns) @@ -437,7 +445,7 @@ def _expand_stars( if pivot_columns: new_selections.extend( - exp.alias_(exp.column(name, table=pivot.alias), name, copy=False) + alias(exp.column(name, table=pivot.alias), name, copy=False) for name in pivot_columns if name not in columns_to_exclude ) @@ -466,7 +474,7 @@ def _expand_stars( ) # Ensures we don't overwrite the initial selections with an empty list - if new_selections: + if new_selections and isinstance(scope.expression, exp.Select): scope.expression.set("expressions", new_selections) @@ -528,7 +536,8 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None: new_selections.append(selection) - scope.expression.set("expressions", new_selections) + if isinstance(scope.expression, exp.Select): + scope.expression.set("expressions", new_selections) def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E: @@ -615,7 +624,7 @@ class Resolver: node, _ = self.scope.selected_sources.get(table_name) - if isinstance(node, exp.Subqueryable): + if isinstance(node, exp.Query): while node and node.alias != table_name: node = node.parent diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index d460e81..214ac0a 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -55,8 +55,8 @@ def qualify_tables( if not table.args.get("catalog") and table.args.get("db"): table.set("catalog", catalog) - if not isinstance(expression, exp.Subqueryable): - for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Unionable)): + if not isinstance(expression, exp.Query): + for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Query)): if isinstance(node, exp.Table): _qualify(node) diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 0eae979..443fa6c 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -138,7 +138,7 @@ class Scope: and _is_derived_table(node) ): self._derived_tables.append(node) - elif isinstance(node, exp.Subqueryable): + elif isinstance(node, exp.UNWRAPPED_QUERIES): self._subqueries.append(node) self._collected = True @@ -225,7 +225,7 @@ class Scope: SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery Returns: - list[exp.Subqueryable]: subqueries + list[exp.Select | exp.Union]: subqueries """ self._ensure_collected() return self._subqueries @@ -486,8 +486,8 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]: Returns: list[Scope]: scope instances """ - if isinstance(expression, exp.Unionable) or ( - isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Unionable) + if isinstance(expression, exp.Query) or ( + isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query) ): return list(_traverse_scope(Scope(expression))) @@ -615,7 +615,7 @@ def _is_derived_table(expression: exp.Subquery) -> bool: as it doesn't introduce a new scope. If an alias is present, it shadows all names under the Subquery, so that's one exception to this rule. """ - return bool(expression.alias or isinstance(expression.this, exp.Subqueryable)) + return bool(expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES)) def _traverse_tables(scope): @@ -786,7 +786,7 @@ def walk_in_scope(expression, bfs=True, prune=None): and _is_derived_table(node) ) or isinstance(node, exp.UDTF) - or isinstance(node, exp.Subqueryable) + or isinstance(node, exp.UNWRAPPED_QUERIES) ): crossed_scope_boundary = True diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 9ffddb5..2e43d21 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -1185,7 +1185,7 @@ def gen(expression: t.Any) -> str: GEN_MAP = { exp.Add: lambda e: _binary(e, "+"), exp.And: lambda e: _binary(e, "AND"), - exp.Anonymous: lambda e: f"{e.this.upper()} {','.join(gen(e) for e in e.expressions)}", + exp.Anonymous: lambda e: _anonymous(e), exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}", exp.Boolean: lambda e: "TRUE" if e.this else "FALSE", exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]", @@ -1219,6 +1219,20 @@ GEN_MAP = { } +def _anonymous(e: exp.Anonymous) -> str: + this = e.this + if isinstance(this, str): + name = this.upper() + elif isinstance(this, exp.Identifier): + name = f'"{this.name}"' if this.quoted else this.name.upper() + else: + raise ValueError( + f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'." + ) + + return f"{name} {','.join(gen(e) for e in e.expressions)}" + + def _binary(e: exp.Binary, op: str) -> str: return f"{gen(e.left)} {op} {gen(e.right)}" diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index b4c7475..36d9da4 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -94,8 +94,20 @@ def unnest(select, parent_select, next_alias_name): else: _replace(predicate, join_key_not_null) + group = select.args.get("group") + + if group: + if {value.this} != set(group.expressions): + select = ( + exp.select(exp.column(value.alias, "_q")) + .from_(select.subquery("_q", copy=False), copy=False) + .group_by(exp.column(value.alias, "_q"), copy=False) + ) + else: + select = select.group_by(value.this, copy=False) + parent_select.join( - select.group_by(value.this, copy=False), + select, on=column.eq(join_key), join_type="LEFT", join_alias=alias, |