diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-10-10 11:29:05 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-10-10 11:29:05 +0000 |
commit | f818ab3b896d52e874634b7c4db3533078c1887f (patch) | |
tree | 8d0f7e4b7f165f33f49da74cb34eb31a0a2d147b /sqlglot/optimizer | |
parent | Releasing debian version 6.2.8-1. (diff) | |
download | sqlglot-f818ab3b896d52e874634b7c4db3533078c1887f.tar.xz sqlglot-f818ab3b896d52e874634b7c4db3533078c1887f.zip |
Merging upstream version 6.3.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r-- | sqlglot/optimizer/annotate_types.py | 158 | ||||
-rw-r--r-- | sqlglot/optimizer/merge_subqueries.py | 44 | ||||
-rw-r--r-- | sqlglot/optimizer/pushdown_predicates.py | 38 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py | 8 | ||||
-rw-r--r-- | sqlglot/optimizer/schema.py | 63 | ||||
-rw-r--r-- | sqlglot/optimizer/scope.py | 20 | ||||
-rw-r--r-- | sqlglot/optimizer/simplify.py | 8 |
7 files changed, 286 insertions, 53 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 3f5f089..a2cef37 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -1,16 +1,20 @@ from sqlglot import exp from sqlglot.helper import ensure_list, subclasses +from sqlglot.optimizer.schema import ensure_schema +from sqlglot.optimizer.scope import Scope, traverse_scope def annotate_types(expression, schema=None, annotators=None, coerces_to=None): """ Recursively infer & annotate types in an expression syntax tree against a schema. + Assumes that we've already executed the optimizer's qualify_columns step. - (TODO -- replace this with a better example after adding some functionality) Example: >>> import sqlglot - >>> annotated_expression = annotate_types(sqlglot.parse_one('5 + 5.3')) - >>> annotated_expression.type + >>> schema = {"y": {"cola": "SMALLINT"}} + >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x" + >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema) + >>> annotated_expr.expressions[0].type # Get the type of "x.cola + 2.5 AS cola" <Type.DOUBLE: 'DOUBLE'> Args: @@ -22,6 +26,8 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None): sqlglot.Expression: expression annotated with types """ + schema = ensure_schema(schema) + return TypeAnnotator(schema, annotators, coerces_to).annotate(expression) @@ -35,10 +41,81 @@ class TypeAnnotator: expr_type: lambda self, expr: self._annotate_binary(expr) for expr_type in subclasses(exp.__name__, exp.Binary) }, - exp.Cast: lambda self, expr: self._annotate_cast(expr), - exp.DataType: lambda self, expr: self._annotate_data_type(expr), + exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"].this), + exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.this), + exp.Alias: lambda self, expr: self._annotate_unary(expr), exp.Literal: lambda self, expr: self._annotate_literal(expr), - exp.Boolean: lambda self, expr: self._annotate_boolean(expr), + exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), + exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL), + exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN), + exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), + exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), + exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME), + exp.CurrentTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME), + exp.DatetimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME), + exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.TimestampAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.TimestampSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.DateStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.DateToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), + exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.If: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), + exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), + exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), + exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.RegexpLike: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), + exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.StrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.UnixToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.VariancePop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), + exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), } # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html @@ -97,43 +174,82 @@ class TypeAnnotator: }, } + TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery) + def __init__(self, schema=None, annotators=None, coerces_to=None): self.schema = schema self.annotators = annotators or self.ANNOTATORS self.coerces_to = coerces_to or self.COERCES_TO def annotate(self, expression): + if isinstance(expression, self.TRAVERSABLES): + for scope in traverse_scope(expression): + subscope_selects = { + name: {select.alias_or_name: select for select in source.selects} + for name, source in scope.sources.items() + if isinstance(source, Scope) + } + + # First annotate the current scope's column references + for col in scope.columns: + source = scope.sources[col.table] + if isinstance(source, exp.Table): + col.type = self.schema.get_column_type(source, col) + else: + col.type = subscope_selects[col.table][col.name].type + + # Then (possibly) annotate the remaining expressions in the scope + self._maybe_annotate(scope.expression) + + return self._maybe_annotate(expression) # This takes care of non-traversable expressions + + def _maybe_annotate(self, expression): if not isinstance(expression, exp.Expression): return None + if expression.type: + return expression # We've already inferred the expression's type + annotator = self.annotators.get(expression.__class__) - return annotator(self, expression) if annotator else self._annotate_args(expression) + return ( + annotator(self, expression) + if annotator + else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN) + ) def _annotate_args(self, expression): for value in expression.args.values(): for v in ensure_list(value): - self.annotate(v) + self._maybe_annotate(v) return expression - def _annotate_cast(self, expression): - expression.type = expression.args["to"].this - return self._annotate_args(expression) - - def _annotate_data_type(self, expression): - expression.type = expression.this - return self._annotate_args(expression) - def _maybe_coerce(self, type1, type2): + # We propagate the NULL / UNKNOWN types upwards if found + if exp.DataType.Type.NULL in (type1, type2): + return exp.DataType.Type.NULL + if exp.DataType.Type.UNKNOWN in (type1, type2): + return exp.DataType.Type.UNKNOWN + return type2 if type2 in self.coerces_to[type1] else type1 def _annotate_binary(self, expression): self._annotate_args(expression) - if isinstance(expression, (exp.Condition, exp.Predicate)): + left_type = expression.left.type + right_type = expression.right.type + + if isinstance(expression, (exp.And, exp.Or)): + if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL: + expression.type = exp.DataType.Type.NULL + elif exp.DataType.Type.NULL in (left_type, right_type): + expression.type = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")) + else: + expression.type = exp.DataType.Type.BOOLEAN + elif isinstance(expression, (exp.Condition, exp.Predicate)): expression.type = exp.DataType.Type.BOOLEAN else: - expression.type = self._maybe_coerce(expression.left.type, expression.right.type) + expression.type = self._maybe_coerce(left_type, right_type) return expression @@ -157,6 +273,6 @@ class TypeAnnotator: return expression - def _annotate_boolean(self, expression): - expression.type = exp.DataType.Type.BOOLEAN - return expression + def _annotate_with_type(self, expression, target_type): + expression.type = target_type + return self._annotate_args(expression) diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index d29c22b..3e435f5 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -44,6 +44,7 @@ UNMERGABLE_ARGS = set(exp.Select.arg_types) - { "joins", "where", "order", + "hint", } @@ -67,21 +68,22 @@ def merge_ctes(expression, leave_tables_isolated=False): singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1] for outer_scope, inner_scope, table in singular_cte_selections: inner_select = inner_scope.expression.unnest() - if _mergeable(outer_scope, inner_select, leave_tables_isolated): - from_or_join = table.find_ancestor(exp.From, exp.Join) - + from_or_join = table.find_ancestor(exp.From, exp.Join) + if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): node_to_replace = table if isinstance(node_to_replace.parent, exp.Alias): node_to_replace = node_to_replace.parent alias = node_to_replace.alias else: alias = table.name + _rename_inner_sources(outer_scope, inner_scope, alias) _merge_from(outer_scope, inner_scope, node_to_replace, alias) _merge_expressions(outer_scope, inner_scope, alias) _merge_joins(outer_scope, inner_scope, from_or_join) _merge_where(outer_scope, inner_scope, from_or_join) _merge_order(outer_scope, inner_scope) + _merge_hints(outer_scope, inner_scope) _pop_cte(inner_scope) return expression @@ -90,9 +92,9 @@ def merge_derived_tables(expression, leave_tables_isolated=False): for outer_scope in traverse_scope(expression): for subquery in outer_scope.derived_tables: inner_select = subquery.unnest() - if _mergeable(outer_scope, inner_select, leave_tables_isolated): + from_or_join = subquery.find_ancestor(exp.From, exp.Join) + if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): alias = subquery.alias_or_name - from_or_join = subquery.find_ancestor(exp.From, exp.Join) inner_scope = outer_scope.sources[alias] _rename_inner_sources(outer_scope, inner_scope, alias) @@ -101,10 +103,11 @@ def merge_derived_tables(expression, leave_tables_isolated=False): _merge_joins(outer_scope, inner_scope, from_or_join) _merge_where(outer_scope, inner_scope, from_or_join) _merge_order(outer_scope, inner_scope) + _merge_hints(outer_scope, inner_scope) return expression -def _mergeable(outer_scope, inner_select, leave_tables_isolated): +def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): """ Return True if `inner_select` can be merged into outer query. @@ -112,6 +115,7 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated): outer_scope (Scope) inner_select (exp.Select) leave_tables_isolated (bool) + from_or_join (exp.From|exp.Join) Returns: bool: True if can be merged """ @@ -123,6 +127,16 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated): and inner_select.args.get("from") and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions) and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1) + and not ( + isinstance(from_or_join, exp.Join) + and inner_select.args.get("where") + and from_or_join.side in {"FULL", "LEFT", "RIGHT"} + ) + and not ( + isinstance(from_or_join, exp.From) + and inner_select.args.get("where") + and any(j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])) + ) ) @@ -170,6 +184,12 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias): """ new_subquery = inner_scope.expression.args.get("from").expressions[0] node_to_replace.replace(new_subquery) + for join_hint in outer_scope.join_hints: + tables = join_hint.find_all(exp.Table) + for table in tables: + if table.alias_or_name == node_to_replace.alias_or_name: + new_table = new_subquery.this if isinstance(new_subquery, exp.Alias) else new_subquery + table.set("this", exp.to_identifier(new_table.alias_or_name)) outer_scope.remove_source(alias) outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]) @@ -273,6 +293,18 @@ def _merge_order(outer_scope, inner_scope): outer_scope.expression.set("order", inner_scope.expression.args.get("order")) +def _merge_hints(outer_scope, inner_scope): + inner_scope_hint = inner_scope.expression.args.get("hint") + if not inner_scope_hint: + return + outer_scope_hint = outer_scope.expression.args.get("hint") + if outer_scope_hint: + for hint_expression in inner_scope_hint.expressions: + outer_scope_hint.append("expressions", hint_expression) + else: + outer_scope.expression.set("hint", inner_scope_hint) + + def _pop_cte(inner_scope): """ Remove CTE from the AST. diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index a070d70..9c8d71d 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -1,3 +1,5 @@ +from collections import defaultdict + from sqlglot import exp from sqlglot.optimizer.normalize import normalized from sqlglot.optimizer.scope import traverse_scope @@ -20,22 +22,30 @@ def pushdown_predicates(expression): Returns: sqlglot.Expression: optimized expression """ - for scope in reversed(traverse_scope(expression)): + scope_ref_count = defaultdict(lambda: 0) + scopes = traverse_scope(expression) + scopes.reverse() + + for scope in scopes: + for _, source in scope.selected_sources.values(): + scope_ref_count[id(source)] += 1 + + for scope in scopes: select = scope.expression where = select.args.get("where") if where: - pushdown(where.this, scope.selected_sources) + pushdown(where.this, scope.selected_sources, scope_ref_count) # joins should only pushdown into itself, not to other joins # so we limit the selected sources to only itself for join in select.args.get("joins") or []: name = join.this.alias_or_name - pushdown(join.args.get("on"), {name: scope.selected_sources[name]}) + pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count) return expression -def pushdown(condition, sources): +def pushdown(condition, sources, scope_ref_count): if not condition: return @@ -45,17 +55,17 @@ def pushdown(condition, sources): predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition]) if cnf_like: - pushdown_cnf(predicates, sources) + pushdown_cnf(predicates, sources, scope_ref_count) else: - pushdown_dnf(predicates, sources) + pushdown_dnf(predicates, sources, scope_ref_count) -def pushdown_cnf(predicates, scope): +def pushdown_cnf(predicates, scope, scope_ref_count): """ If the predicates are in CNF like form, we can simply replace each block in the parent. """ for predicate in predicates: - for node in nodes_for_predicate(predicate, scope).values(): + for node in nodes_for_predicate(predicate, scope, scope_ref_count).values(): if isinstance(node, exp.Join): predicate.replace(exp.TRUE) node.on(predicate, copy=False) @@ -65,7 +75,7 @@ def pushdown_cnf(predicates, scope): node.where(replace_aliases(node, predicate), copy=False) -def pushdown_dnf(predicates, scope): +def pushdown_dnf(predicates, scope, scope_ref_count): """ If the predicates are in DNF form, we can only push down conditions that are in all blocks. Additionally, we can't remove predicates from their original form. @@ -91,7 +101,7 @@ def pushdown_dnf(predicates, scope): # (a.x AND and a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z) for table in sorted(pushdown_tables): for predicate in predicates: - nodes = nodes_for_predicate(predicate, scope) + nodes = nodes_for_predicate(predicate, scope, scope_ref_count) if table not in nodes: continue @@ -120,7 +130,7 @@ def pushdown_dnf(predicates, scope): node.where(replace_aliases(node, predicate), copy=False) -def nodes_for_predicate(predicate, sources): +def nodes_for_predicate(predicate, sources, scope_ref_count): nodes = {} tables = exp.column_table_names(predicate) where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where) @@ -133,7 +143,7 @@ def nodes_for_predicate(predicate, sources): if node and where_condition: node = node.find_ancestor(exp.Join, exp.From) - # a node can reference a CTE which should be push down + # a node can reference a CTE which should be pushed down if isinstance(node, exp.From) and not isinstance(source, exp.Table): node = source.expression @@ -142,7 +152,9 @@ def nodes_for_predicate(predicate, sources): return {} nodes[table] = node elif isinstance(node, exp.Select) and len(tables) == 1: - if not node.args.get("group"): + # we can't push down predicates to select statements if they are referenced in + # multiple places. + if not node.args.get("group") and scope_ref_count[id(source)] < 2: nodes[table] = node return nodes diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 72ce256..7d77ef1 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -31,8 +31,8 @@ def qualify_columns(expression, schema): _pop_table_column_aliases(scope.derived_tables) _expand_using(scope, resolver) _expand_group_by(scope, resolver) - _expand_order_by(scope) _qualify_columns(scope, resolver) + _expand_order_by(scope) if not isinstance(scope.expression, exp.UDTF): _expand_stars(scope, resolver) _qualify_outputs(scope) @@ -235,7 +235,7 @@ def _expand_stars(scope, resolver): for table in tables: if table not in scope.sources: raise OptimizeError(f"Unknown table: {table}") - columns = resolver.get_source_columns(table) + columns = resolver.get_source_columns(table, only_visible=True) table_id = id(table) for name in columns: if name not in except_columns.get(table_id, set()): @@ -332,7 +332,7 @@ class _Resolver: self._all_columns = set(column for columns in self._get_all_source_columns().values() for column in columns) return self._all_columns - def get_source_columns(self, name): + def get_source_columns(self, name, only_visible=False): """Resolve the source columns for a given source `name`""" if name not in self.scope.sources: raise OptimizeError(f"Unknown table: {name}") @@ -342,7 +342,7 @@ class _Resolver: # If referencing a table, return the columns from the schema if isinstance(source, exp.Table): try: - return self.schema.column_names(source) + return self.schema.column_names(source, only_visible) except Exception as e: raise OptimizeError(str(e)) from e diff --git a/sqlglot/optimizer/schema.py b/sqlglot/optimizer/schema.py index 1bbd86a..d7743c9 100644 --- a/sqlglot/optimizer/schema.py +++ b/sqlglot/optimizer/schema.py @@ -9,16 +9,28 @@ class Schema(abc.ABC): """Abstract base class for database schemas""" @abc.abstractmethod - def column_names(self, table): + def column_names(self, table, only_visible=False): """ Get the column names for a table. - Args: table (sqlglot.expressions.Table): Table expression instance + only_visible (bool): Whether to include invisible columns Returns: list[str]: list of column names """ + @abc.abstractmethod + def get_column_type(self, table, column): + """ + Get the exp.DataType type of a column in the schema. + + Args: + table (sqlglot.expressions.Table): The source table. + column (sqlglot.expressions.Column): The target column. + Returns: + sqlglot.expressions.DataType.Type: The resulting column type. + """ + class MappingSchema(Schema): """ @@ -29,10 +41,19 @@ class MappingSchema(Schema): 1. {table: {col: type}} 2. {db: {table: {col: type}}} 3. {catalog: {db: {table: {col: type}}}} + visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns + are assumed to be visible. The nesting should mirror that of the schema: + 1. {table: set(*cols)}} + 2. {db: {table: set(*cols)}}} + 3. {catalog: {db: {table: set(*cols)}}}} + dialect (str): The dialect to be used for custom type mappings. """ - def __init__(self, schema): + def __init__(self, schema, visible=None, dialect=None): self.schema = schema + self.visible = visible + self.dialect = dialect + self._type_mapping_cache = {} depth = _dict_depth(schema) @@ -49,7 +70,7 @@ class MappingSchema(Schema): self.forbidden_args = {"catalog", "db", "this"} - set(self.supported_table_args) - def column_names(self, table): + def column_names(self, table, only_visible=False): if not isinstance(table.this, exp.Identifier): return fs_get(table) @@ -58,7 +79,39 @@ class MappingSchema(Schema): for forbidden in self.forbidden_args: if table.text(forbidden): raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}") - return list(_nested_get(self.schema, *zip(self.supported_table_args, args))) + + columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args))) + if not only_visible or not self.visible: + return columns + + visible = _nested_get(self.visible, *zip(self.supported_table_args, args)) + return [col for col in columns if col in visible] + + def get_column_type(self, table, column): + try: + schema_type = self.schema.get(table.name, {}).get(column.name).upper() + return self._convert_type(schema_type) + except: + raise OptimizeError(f"Failed to get type for column {column.sql()}") + + def _convert_type(self, schema_type): + """ + Convert a type represented as a string to the corresponding exp.DataType.Type object. + + Args: + schema_type (str): The type we want to convert. + Returns: + sqlglot.expressions.DataType.Type: The resulting expression type. + """ + if schema_type not in self._type_mapping_cache: + try: + self._type_mapping_cache[schema_type] = exp.maybe_parse( + schema_type, into=exp.DataType, dialect=self.dialect + ).this + except AttributeError: + raise OptimizeError(f"Failed to convert type {schema_type}") + + return self._type_mapping_cache[schema_type] def ensure_schema(schema): diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 6332cdd..89de517 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -68,6 +68,7 @@ class Scope: self._selected_sources = None self._columns = None self._external_columns = None + self._join_hints = None def branch(self, expression, scope_type, chain_sources=None, **kwargs): """Branch from the current scope to a new, inner scope""" @@ -85,14 +86,17 @@ class Scope: self._subqueries = [] self._derived_tables = [] self._raw_columns = [] + self._join_hints = [] for node, parent, _ in self.walk(bfs=False): if node is self.expression: continue elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): self._raw_columns.append(node) - elif isinstance(node, exp.Table): + elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): self._tables.append(node) + elif isinstance(node, exp.JoinHint): + self._join_hints.append(node) elif isinstance(node, exp.UDTF): self._derived_tables.append(node) elif isinstance(node, exp.CTE): @@ -246,7 +250,7 @@ class Scope: table only becomes a selected source if it's included in a FROM or JOIN clause. Returns: - dict[str, (exp.Table|exp.Subquery, exp.Table|Scope)]: selected sources and nodes + dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes """ if self._selected_sources is None: referenced_names = [] @@ -310,6 +314,18 @@ class Scope: self._external_columns = [c for c in self.columns if c.table not in self.selected_sources] return self._external_columns + @property + def join_hints(self): + """ + Hints that exist in the scope that reference tables + + Returns: + list[exp.JoinHint]: Join hints that are referenced within the scope + """ + if self._join_hints is None: + return [] + return self._join_hints + def source_columns(self, source_name): """ Get all columns in the current scope for a particular source. diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 319e6b6..c077906 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -56,12 +56,16 @@ def simplify_not(expression): NOT (x AND y) -> NOT x OR NOT y """ if isinstance(expression, exp.Not): + if isinstance(expression.this, exp.Null): + return NULL if isinstance(expression.this, exp.Paren): condition = expression.this.unnest() if isinstance(condition, exp.And): return exp.or_(exp.not_(condition.left), exp.not_(condition.right)) if isinstance(condition, exp.Or): return exp.and_(exp.not_(condition.left), exp.not_(condition.right)) + if isinstance(condition, exp.Null): + return NULL if always_true(expression.this): return FALSE if expression.this == FALSE: @@ -95,10 +99,10 @@ def simplify_connectors(expression): return left if isinstance(expression, exp.And): - if NULL in (left, right): - return NULL if FALSE in (left, right): return FALSE + if NULL in (left, right): + return NULL if always_true(left) and always_true(right): return TRUE if always_true(left): |