diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-06-16 09:41:18 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-06-16 09:41:18 +0000 |
commit | 67578a7602a5be7eb51f324086c8d49bcf8b7498 (patch) | |
tree | 0b7515c922d1c383cea24af5175379cfc8edfd15 /sqlglot/optimizer | |
parent | Releasing debian version 15.2.0-1. (diff) | |
download | sqlglot-67578a7602a5be7eb51f324086c8d49bcf8b7498.tar.xz sqlglot-67578a7602a5be7eb51f324086c8d49bcf8b7498.zip |
Merging upstream version 16.2.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r-- | sqlglot/optimizer/annotate_types.py | 516 | ||||
-rw-r--r-- | sqlglot/optimizer/canonicalize.py | 2 | ||||
-rw-r--r-- | sqlglot/optimizer/eliminate_joins.py | 4 | ||||
-rw-r--r-- | sqlglot/optimizer/isolate_table_selects.py | 2 | ||||
-rw-r--r-- | sqlglot/optimizer/merge_subqueries.py | 9 | ||||
-rw-r--r-- | sqlglot/optimizer/optimize_joins.py | 33 | ||||
-rw-r--r-- | sqlglot/optimizer/optimizer.py | 2 | ||||
-rw-r--r-- | sqlglot/optimizer/pushdown_predicates.py | 8 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py | 6 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_tables.py | 6 | ||||
-rw-r--r-- | sqlglot/optimizer/scope.py | 2 |
11 files changed, 288 insertions, 302 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 6238759..39e2c53 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -1,13 +1,25 @@ +from __future__ import annotations + +import typing as t + from sqlglot import exp +from sqlglot._typing import E from sqlglot.helper import ensure_list, subclasses from sqlglot.optimizer.scope import Scope, traverse_scope -from sqlglot.schema import ensure_schema +from sqlglot.schema import Schema, ensure_schema + +if t.TYPE_CHECKING: + B = t.TypeVar("B", bound=exp.Binary) -def annotate_types(expression, schema=None, annotators=None, coerces_to=None): +def annotate_types( + expression: E, + schema: t.Optional[t.Dict | Schema] = None, + annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None, + coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None, +) -> E: """ - Recursively infer & annotate types in an expression syntax tree against a schema. - Assumes that we've already executed the optimizer's qualify_columns step. + Infers the types of an expression, annotating its AST accordingly. Example: >>> import sqlglot @@ -18,12 +30,13 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None): <Type.DOUBLE: 'DOUBLE'> Args: - expression (sqlglot.Expression): Expression to annotate. - schema (dict|sqlglot.optimizer.Schema): Database schema. - annotators (dict): Maps expression type to corresponding annotation function. - coerces_to (dict): Maps expression type to set of types that it can be coerced into. + expression: Expression to annotate. + schema: Database schema. + annotators: Maps expression type to corresponding annotation function. + coerces_to: Maps expression type to set of types that it can be coerced into. + Returns: - sqlglot.Expression: expression annotated with types + The expression annotated with types. """ schema = ensure_schema(schema) @@ -31,276 +44,241 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None): return TypeAnnotator(schema, annotators, coerces_to).annotate(expression) -class TypeAnnotator: - ANNOTATORS = { - **{ - expr_type: lambda self, expr: self._annotate_unary(expr) - for expr_type in subclasses(exp.__name__, exp.Unary) - }, - **{ - expr_type: lambda self, expr: self._annotate_binary(expr) - for expr_type in subclasses(exp.__name__, exp.Binary) - }, - exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]), - exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]), - exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()), - exp.Alias: lambda self, expr: self._annotate_unary(expr), - exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), - exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), - exp.Literal: lambda self, expr: self._annotate_literal(expr), - exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), - exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL), - exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN), - exp.ApproxDistinct: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.BIGINT - ), - exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.Min: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"), - exp.Max: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"), - exp.Sum: lambda self, expr: self._annotate_by_args( - expr, "this", "expressions", promote=True - ), - exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), - exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.CurrentDatetime: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.DATETIME - ), - exp.CurrentTime: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.TIMESTAMP - ), - exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.TIMESTAMP - ), - exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.DatetimeAdd: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.DATETIME - ), - exp.DatetimeSub: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.DATETIME - ), - exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.TimestampAdd: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.TIMESTAMP - ), - exp.TimestampSub: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.TIMESTAMP - ), - exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), - exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), - exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.DateStrToDate: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.DATE - ), - exp.DateToDateStr: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.VARCHAR - ), - exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), - exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"), - exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"), - exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"), - exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"), - exp.Concat: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.GroupConcat: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.VARCHAR - ), - exp.ArrayConcat: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.VARCHAR - ), - exp.ArraySize: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), - exp.Map: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP), - exp.VarMap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP), - exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL), - exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"), - exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), - exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), - exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.ApproxQuantile: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.DOUBLE - ), - exp.RegexpLike: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.BOOLEAN - ), - exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.StrToTime: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.TIMESTAMP - ), - exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.VARCHAR - ), - exp.TimeStrToDate: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.DATE - ), - exp.TimeStrToTime: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.TIMESTAMP - ), - exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.VARCHAR - ), - exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.UnixToTime: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.TIMESTAMP - ), - exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.VARCHAR - ), - exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.VariancePop: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.DOUBLE - ), - exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), - exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), - } +def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]: + return lambda self, e: self._annotate_with_type(e, data_type) - # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html - COERCES_TO = { - # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT - exp.DataType.Type.TEXT: set(), - exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT}, - exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT}, - exp.DataType.Type.NCHAR: { - exp.DataType.Type.VARCHAR, - exp.DataType.Type.NVARCHAR, + +class _TypeAnnotator(type): + def __new__(cls, clsname, bases, attrs): + klass = super().__new__(cls, clsname, bases, attrs) + + # Highest-to-lowest type precedence, as specified in Spark's docs (ANSI): + # https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html + text_precedence = ( exp.DataType.Type.TEXT, - }, - exp.DataType.Type.CHAR: { - exp.DataType.Type.NCHAR, - exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR, - exp.DataType.Type.TEXT, - }, - # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE - exp.DataType.Type.DOUBLE: set(), - exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE}, - exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}, - exp.DataType.Type.BIGINT: { - exp.DataType.Type.DECIMAL, - exp.DataType.Type.FLOAT, + exp.DataType.Type.VARCHAR, + exp.DataType.Type.NCHAR, + exp.DataType.Type.CHAR, + ) + numeric_precedence = ( exp.DataType.Type.DOUBLE, + exp.DataType.Type.FLOAT, + exp.DataType.Type.DECIMAL, + exp.DataType.Type.BIGINT, + exp.DataType.Type.INT, + exp.DataType.Type.SMALLINT, + exp.DataType.Type.TINYINT, + ) + timelike_precedence = ( + exp.DataType.Type.TIMESTAMPLTZ, + exp.DataType.Type.TIMESTAMPTZ, + exp.DataType.Type.TIMESTAMP, + exp.DataType.Type.DATETIME, + exp.DataType.Type.DATE, + ) + + for type_precedence in (text_precedence, numeric_precedence, timelike_precedence): + coerces_to = set() + for data_type in type_precedence: + klass.COERCES_TO[data_type] = coerces_to.copy() + coerces_to |= {data_type} + + return klass + + +class TypeAnnotator(metaclass=_TypeAnnotator): + TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { + exp.DataType.Type.BIGINT: { + exp.ApproxDistinct, + exp.ArraySize, + exp.Count, + exp.Length, + }, + exp.DataType.Type.BOOLEAN: { + exp.Between, + exp.Boolean, + exp.In, + exp.RegexpLike, + }, + exp.DataType.Type.DATE: { + exp.CurrentDate, + exp.Date, + exp.DateAdd, + exp.DateStrToDate, + exp.DateSub, + exp.DateTrunc, + exp.DiToDate, + exp.StrToDate, + exp.TimeStrToDate, + exp.TsOrDsToDate, + }, + exp.DataType.Type.DATETIME: { + exp.CurrentDatetime, + exp.DatetimeAdd, + exp.DatetimeSub, + }, + exp.DataType.Type.DOUBLE: { + exp.ApproxQuantile, + exp.Avg, + exp.Exp, + exp.Ln, + exp.Log, + exp.Log2, + exp.Log10, + exp.Pow, + exp.Quantile, + exp.Round, + exp.SafeDivide, + exp.Sqrt, + exp.Stddev, + exp.StddevPop, + exp.StddevSamp, + exp.Variance, + exp.VariancePop, }, exp.DataType.Type.INT: { - exp.DataType.Type.BIGINT, - exp.DataType.Type.DECIMAL, - exp.DataType.Type.FLOAT, - exp.DataType.Type.DOUBLE, + exp.Ceil, + exp.DateDiff, + exp.DatetimeDiff, + exp.Extract, + exp.TimestampDiff, + exp.TimeDiff, + exp.DateToDi, + exp.Floor, + exp.Levenshtein, + exp.StrPosition, + exp.TsOrDiToDi, }, - exp.DataType.Type.SMALLINT: { - exp.DataType.Type.INT, - exp.DataType.Type.BIGINT, - exp.DataType.Type.DECIMAL, - exp.DataType.Type.FLOAT, - exp.DataType.Type.DOUBLE, + exp.DataType.Type.TIMESTAMP: { + exp.CurrentTime, + exp.CurrentTimestamp, + exp.StrToTime, + exp.TimeAdd, + exp.TimeStrToTime, + exp.TimeSub, + exp.TimestampAdd, + exp.TimestampSub, + exp.UnixToTime, }, exp.DataType.Type.TINYINT: { - exp.DataType.Type.SMALLINT, - exp.DataType.Type.INT, - exp.DataType.Type.BIGINT, - exp.DataType.Type.DECIMAL, - exp.DataType.Type.FLOAT, - exp.DataType.Type.DOUBLE, + exp.Day, + exp.Month, + exp.Week, + exp.Year, }, - # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ - exp.DataType.Type.TIMESTAMPLTZ: set(), - exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ}, - exp.DataType.Type.TIMESTAMP: { - exp.DataType.Type.TIMESTAMPTZ, - exp.DataType.Type.TIMESTAMPLTZ, + exp.DataType.Type.VARCHAR: { + exp.ArrayConcat, + exp.Concat, + exp.ConcatWs, + exp.DateToDateStr, + exp.GroupConcat, + exp.Initcap, + exp.Lower, + exp.SafeConcat, + exp.Substring, + exp.TimeToStr, + exp.TimeToTimeStr, + exp.Trim, + exp.TsOrDsToDateStr, + exp.UnixToStr, + exp.UnixToTimeStr, + exp.Upper, }, - exp.DataType.Type.DATETIME: { - exp.DataType.Type.TIMESTAMP, - exp.DataType.Type.TIMESTAMPTZ, - exp.DataType.Type.TIMESTAMPLTZ, + } + + ANNOTATORS = { + **{ + expr_type: lambda self, e: self._annotate_unary(e) + for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) }, - exp.DataType.Type.DATE: { - exp.DataType.Type.DATETIME, - exp.DataType.Type.TIMESTAMP, - exp.DataType.Type.TIMESTAMPTZ, - exp.DataType.Type.TIMESTAMPLTZ, + **{ + expr_type: lambda self, e: self._annotate_binary(e) + for expr_type in subclasses(exp.__name__, exp.Binary) + }, + **{ + expr_type: _annotate_with_type_lambda(data_type) + for data_type, expressions in TYPE_TO_EXPRESSIONS.items() + for expr_type in expressions }, + exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), + exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), + exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), + exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), + exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), + exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), + exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), + exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"), + exp.Literal: lambda self, e: self._annotate_literal(e), + exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP), + exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), + exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), + exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), + exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), + exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), + exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP), } - TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery) + # Specifies what types a given type can be coerced into (autofilled) + COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {} - def __init__(self, schema=None, annotators=None, coerces_to=None): + def __init__( + self, + schema: Schema, + annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None, + coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None, + ) -> None: self.schema = schema self.annotators = annotators or self.ANNOTATORS self.coerces_to = coerces_to or self.COERCES_TO - def annotate(self, expression): - if isinstance(expression, self.TRAVERSABLES): - for scope in traverse_scope(expression): - selects = {} - for name, source in scope.sources.items(): - if not isinstance(source, Scope): - continue - if isinstance(source.expression, exp.UDTF): - values = [] - - if isinstance(source.expression, exp.Lateral): - if isinstance(source.expression.this, exp.Explode): - values = [source.expression.this.this] - else: - values = source.expression.expressions[0].expressions - - if not values: - continue - - selects[name] = { - alias: column - for alias, column in zip( - source.expression.alias_column_names, - values, - ) - } + def annotate(self, expression: E) -> E: + for scope in traverse_scope(expression): + selects = {} + for name, source in scope.sources.items(): + if not isinstance(source, Scope): + continue + if isinstance(source.expression, exp.UDTF): + values = [] + + if isinstance(source.expression, exp.Lateral): + if isinstance(source.expression.this, exp.Explode): + values = [source.expression.this.this] else: - selects[name] = { - select.alias_or_name: select for select in source.expression.selects - } - # First annotate the current scope's column references - for col in scope.columns: - if not col.table: + values = source.expression.expressions[0].expressions + + if not values: continue - source = scope.sources.get(col.table) - if isinstance(source, exp.Table): - col.type = self.schema.get_column_type(source, col) - elif source and col.table in selects and col.name in selects[col.table]: - col.type = selects[col.table][col.name].type - # Then (possibly) annotate the remaining expressions in the scope - self._maybe_annotate(scope.expression) + selects[name] = { + alias: column + for alias, column in zip( + source.expression.alias_column_names, + values, + ) + } + else: + selects[name] = { + select.alias_or_name: select for select in source.expression.selects + } + + # First annotate the current scope's column references + for col in scope.columns: + if not col.table: + continue + + source = scope.sources.get(col.table) + if isinstance(source, exp.Table): + col.type = self.schema.get_column_type(source, col) + elif source and col.table in selects and col.name in selects[col.table]: + col.type = selects[col.table][col.name].type + + # Then (possibly) annotate the remaining expressions in the scope + self._maybe_annotate(scope.expression) + return self._maybe_annotate(expression) # This takes care of non-traversable expressions - def _maybe_annotate(self, expression): + def _maybe_annotate(self, expression: E) -> E: if expression.type: return expression # We've already inferred the expression's type @@ -312,13 +290,15 @@ class TypeAnnotator: else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN) ) - def _annotate_args(self, expression): + def _annotate_args(self, expression: E) -> E: for _, value in expression.iter_expressions(): self._maybe_annotate(value) return expression - def _maybe_coerce(self, type1, type2): + def _maybe_coerce( + self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type + ) -> exp.DataType.Type: # We propagate the NULL / UNKNOWN types upwards if found if isinstance(type1, exp.DataType): type1 = type1.this @@ -330,9 +310,14 @@ class TypeAnnotator: if exp.DataType.Type.UNKNOWN in (type1, type2): return exp.DataType.Type.UNKNOWN - return type2 if type2 in self.coerces_to.get(type1, {}) else type1 + return type2 if type2 in self.coerces_to.get(type1, {}) else type1 # type: ignore - def _annotate_binary(self, expression): + # Note: the following "no_type_check" decorators were added because mypy was yelling due + # to assigning Type values to expression.type (since its getter returns Optional[DataType]). + # This is a known mypy issue: https://github.com/python/mypy/issues/3004 + + @t.no_type_check + def _annotate_binary(self, expression: B) -> B: self._annotate_args(expression) left_type = expression.left.type.this @@ -354,7 +339,8 @@ class TypeAnnotator: return expression - def _annotate_unary(self, expression): + @t.no_type_check + def _annotate_unary(self, expression: E) -> E: self._annotate_args(expression) if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren): @@ -364,7 +350,8 @@ class TypeAnnotator: return expression - def _annotate_literal(self, expression): + @t.no_type_check + def _annotate_literal(self, expression: exp.Literal) -> exp.Literal: if expression.is_string: expression.type = exp.DataType.Type.VARCHAR elif expression.is_int: @@ -374,13 +361,16 @@ class TypeAnnotator: return expression - def _annotate_with_type(self, expression, target_type): + @t.no_type_check + def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E: expression.type = target_type return self._annotate_args(expression) - def _annotate_by_args(self, expression, *args, promote=False): + @t.no_type_check + def _annotate_by_args(self, expression: E, *args: str, promote: bool = False) -> E: self._annotate_args(expression) - expressions = [] + + expressions: t.List[exp.Expression] = [] for arg in args: arg_expr = expression.args.get(arg) expressions.extend(expr for expr in ensure_list(arg_expr) if expr) diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py index da2fce8..015b06a 100644 --- a/sqlglot/optimizer/canonicalize.py +++ b/sqlglot/optimizer/canonicalize.py @@ -26,7 +26,7 @@ def canonicalize(expression: exp.Expression) -> exp.Expression: def add_text_to_concat(node: exp.Expression) -> exp.Expression: if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES: - node = exp.Concat(this=node.this, expression=node.expression) + node = exp.Concat(expressions=[node.left, node.right]) return node diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py index 27de9c7..cd8ba3b 100644 --- a/sqlglot/optimizer/eliminate_joins.py +++ b/sqlglot/optimizer/eliminate_joins.py @@ -32,7 +32,7 @@ def eliminate_joins(expression): # Reverse the joins so we can remove chains of unused joins for join in reversed(joins): - alias = join.this.alias_or_name + alias = join.alias_or_name if _should_eliminate_join(scope, join, alias): join.pop() scope.remove_source(alias) @@ -126,7 +126,7 @@ def join_condition(join): tuple[list[str], list[str], exp.Expression]: Tuple of (source key, join key, remaining predicate) """ - name = join.this.alias_or_name + name = join.alias_or_name on = (join.args.get("on") or exp.true()).copy() source_key = [] join_key = [] diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py index 5dfa4aa..79e3ed5 100644 --- a/sqlglot/optimizer/isolate_table_selects.py +++ b/sqlglot/optimizer/isolate_table_selects.py @@ -21,7 +21,7 @@ def isolate_table_selects(expression, schema=None): source.replace( exp.select("*") .from_( - alias(source, source.name or source.alias, table=True), + alias(source, source.alias_or_name, table=True), copy=False, ) .subquery(source.alias, copy=False) diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index f9c9664..fefe96e 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -145,7 +145,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): if not isinstance(from_or_join, exp.Join): return False - alias = from_or_join.this.alias_or_name + alias = from_or_join.alias_or_name on = from_or_join.args.get("on") if not on: @@ -253,10 +253,6 @@ def _merge_joins(outer_scope, inner_scope, from_or_join): """ new_joins = [] - comma_joins = inner_scope.expression.args.get("from").expressions[1:] - for subquery in comma_joins: - new_joins.append(exp.Join(this=subquery, kind="CROSS")) - outer_scope.add_source(subquery.alias_or_name, inner_scope.sources[subquery.alias_or_name]) joins = inner_scope.expression.args.get("joins") or [] for join in joins: @@ -328,13 +324,12 @@ def _merge_where(outer_scope, inner_scope, from_or_join): if source == from_or_join.alias_or_name: break - if set(exp.column_table_names(where.this)) <= sources: + if exp.column_table_names(where.this) <= sources: from_or_join.on(where.this, copy=False) from_or_join.set("on", from_or_join.args.get("on")) return expression.where(where.this, copy=False) - expression.set("where", expression.args.get("where")) def _merge_order(outer_scope, inner_scope): diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py index 4e0c3a1..d51276f 100644 --- a/sqlglot/optimizer/optimize_joins.py +++ b/sqlglot/optimizer/optimize_joins.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +import typing as t + from sqlglot import exp from sqlglot.helper import tsort @@ -13,25 +17,28 @@ def optimize_joins(expression): >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql() 'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a' """ + for select in expression.find_all(exp.Select): references = {} cross_joins = [] for join in select.args.get("joins", []): - name = join.this.alias_or_name - tables = other_table_names(join, name) + tables = other_table_names(join) if tables: for table in tables: references[table] = references.get(table, []) + [join] else: - cross_joins.append((name, join)) + cross_joins.append((join.alias_or_name, join)) for name, join in cross_joins: for dep in references.get(name, []): on = dep.args["on"] if isinstance(on, exp.Connector): + if len(other_table_names(dep)) < 2: + continue + for predicate in on.flatten(): if name in exp.column_table_names(predicate): predicate.replace(exp.true()) @@ -47,17 +54,12 @@ def reorder_joins(expression): Reorder joins by topological sort order based on predicate references. """ for from_ in expression.find_all(exp.From): - head = from_.this parent = from_.parent - joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])} - dag = {head.alias_or_name: []} - - for name, join in joins.items(): - dag[name] = other_table_names(join, name) - + joins = {join.alias_or_name: join for join in parent.args.get("joins", [])} + dag = {name: other_table_names(join) for name, join in joins.items()} parent.set( "joins", - [joins[name] for name in tsort(dag) if name != head.alias_or_name], + [joins[name] for name in tsort(dag) if name != from_.alias_or_name], ) return expression @@ -75,9 +77,6 @@ def normalize(expression): return expression -def other_table_names(join, exclude): - return [ - name - for name in (exp.column_table_names(join.args.get("on") or exp.true())) - if name != exclude - ] +def other_table_names(join: exp.Join) -> t.Set[str]: + on = join.args.get("on") + return exp.column_table_names(on, join.alias_or_name) if on else set() diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index dbe33a2..abac63b 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -78,7 +78,7 @@ def optimize( "schema": schema, "dialect": dialect, "isolate_tables": True, # needed for other optimizations to perform well - "quote_identifiers": False, # this happens in canonicalize + "quote_identifiers": False, **kwargs, } diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index b89a82b..fb1662d 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -41,7 +41,7 @@ def pushdown_predicates(expression): # joins should only pushdown into itself, not to other joins # so we limit the selected sources to only itself for join in select.args.get("joins") or []: - name = join.this.alias_or_name + name = join.alias_or_name pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count) return expression @@ -93,10 +93,10 @@ def pushdown_dnf(predicates, scope, scope_ref_count): pushdown_tables = set() for a in predicates: - a_tables = set(exp.column_table_names(a)) + a_tables = exp.column_table_names(a) for b in predicates: - a_tables &= set(exp.column_table_names(b)) + a_tables &= exp.column_table_names(b) pushdown_tables.update(a_tables) @@ -147,7 +147,7 @@ def nodes_for_predicate(predicate, sources, scope_ref_count): tables = exp.column_table_names(predicate) where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where) - for table in tables: + for table in sorted(tables): node, source = sources.get(table) or (None, None) # if the predicate is in a where statement we can try to push it down diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 4a31171..aba9a7e 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -14,7 +14,7 @@ from sqlglot.schema import Schema, ensure_schema def qualify_columns( expression: exp.Expression, - schema: dict | Schema, + schema: t.Dict | Schema, expand_alias_refs: bool = True, infer_schema: t.Optional[bool] = None, ) -> exp.Expression: @@ -93,7 +93,7 @@ def _pop_table_column_aliases(derived_tables): def _expand_using(scope, resolver): joins = list(scope.find_all(exp.Join)) - names = {join.this.alias for join in joins} + names = {join.alias_or_name for join in joins} ordered = [key for key in scope.selected_sources if key not in names] # Mapping of automatically joined column names to an ordered set of source names (dict). @@ -105,7 +105,7 @@ def _expand_using(scope, resolver): if not using: continue - join_table = join.this.alias_or_name + join_table = join.alias_or_name columns = {} diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index fcc5f26..9c931d6 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -91,11 +91,13 @@ def qualify_tables( ) elif isinstance(source, Scope) and source.is_udtf: udtf = source.expression - table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_alias_name()) + table_alias = udtf.args.get("alias") or exp.TableAlias( + this=exp.to_identifier(next_alias_name()) + ) udtf.set("alias", table_alias) if not table_alias.name: - table_alias.set("this", next_alias_name()) + table_alias.set("this", exp.to_identifier(next_alias_name())) if isinstance(udtf, exp.Values) and not table_alias.columns: for i, e in enumerate(udtf.expressions[0].expressions): table_alias.append("columns", exp.to_identifier(f"_col_{i}")) diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 9ffb4d6..aa56b83 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -620,7 +620,7 @@ def _traverse_tables(scope): table_name = expression.name source_name = expression.alias_or_name - if table_name in scope.sources: + if table_name in scope.sources and not expression.db: # This is a reference to a parent source (e.g. a CTE), not an actual table, unless # it is pivoted, because then we get back a new table and hence a new source. pivots = expression.args.get("pivots") |