summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-06-16 09:41:18 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-06-16 09:41:18 +0000
commit67578a7602a5be7eb51f324086c8d49bcf8b7498 (patch)
tree0b7515c922d1c383cea24af5175379cfc8edfd15 /sqlglot/optimizer
parentReleasing debian version 15.2.0-1. (diff)
downloadsqlglot-67578a7602a5be7eb51f324086c8d49bcf8b7498.tar.xz
sqlglot-67578a7602a5be7eb51f324086c8d49bcf8b7498.zip
Merging upstream version 16.2.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r--sqlglot/optimizer/annotate_types.py516
-rw-r--r--sqlglot/optimizer/canonicalize.py2
-rw-r--r--sqlglot/optimizer/eliminate_joins.py4
-rw-r--r--sqlglot/optimizer/isolate_table_selects.py2
-rw-r--r--sqlglot/optimizer/merge_subqueries.py9
-rw-r--r--sqlglot/optimizer/optimize_joins.py33
-rw-r--r--sqlglot/optimizer/optimizer.py2
-rw-r--r--sqlglot/optimizer/pushdown_predicates.py8
-rw-r--r--sqlglot/optimizer/qualify_columns.py6
-rw-r--r--sqlglot/optimizer/qualify_tables.py6
-rw-r--r--sqlglot/optimizer/scope.py2
11 files changed, 288 insertions, 302 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index 6238759..39e2c53 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -1,13 +1,25 @@
+from __future__ import annotations
+
+import typing as t
+
from sqlglot import exp
+from sqlglot._typing import E
from sqlglot.helper import ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
-from sqlglot.schema import ensure_schema
+from sqlglot.schema import Schema, ensure_schema
+
+if t.TYPE_CHECKING:
+ B = t.TypeVar("B", bound=exp.Binary)
-def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
+def annotate_types(
+ expression: E,
+ schema: t.Optional[t.Dict | Schema] = None,
+ annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
+ coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
+) -> E:
"""
- Recursively infer & annotate types in an expression syntax tree against a schema.
- Assumes that we've already executed the optimizer's qualify_columns step.
+ Infers the types of an expression, annotating its AST accordingly.
Example:
>>> import sqlglot
@@ -18,12 +30,13 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
<Type.DOUBLE: 'DOUBLE'>
Args:
- expression (sqlglot.Expression): Expression to annotate.
- schema (dict|sqlglot.optimizer.Schema): Database schema.
- annotators (dict): Maps expression type to corresponding annotation function.
- coerces_to (dict): Maps expression type to set of types that it can be coerced into.
+ expression: Expression to annotate.
+ schema: Database schema.
+ annotators: Maps expression type to corresponding annotation function.
+ coerces_to: Maps expression type to set of types that it can be coerced into.
+
Returns:
- sqlglot.Expression: expression annotated with types
+ The expression annotated with types.
"""
schema = ensure_schema(schema)
@@ -31,276 +44,241 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
-class TypeAnnotator:
- ANNOTATORS = {
- **{
- expr_type: lambda self, expr: self._annotate_unary(expr)
- for expr_type in subclasses(exp.__name__, exp.Unary)
- },
- **{
- expr_type: lambda self, expr: self._annotate_binary(expr)
- for expr_type in subclasses(exp.__name__, exp.Binary)
- },
- exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
- exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
- exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
- exp.Alias: lambda self, expr: self._annotate_unary(expr),
- exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
- exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
- exp.Literal: lambda self, expr: self._annotate_literal(expr),
- exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
- exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
- exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
- exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.BIGINT
- ),
- exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.Min: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
- exp.Max: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
- exp.Sum: lambda self, expr: self._annotate_by_args(
- expr, "this", "expressions", promote=True
- ),
- exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
- exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.DATETIME
- ),
- exp.CurrentTime: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.TIMESTAMP
- ),
- exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.TIMESTAMP
- ),
- exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.DATETIME
- ),
- exp.DatetimeSub: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.DATETIME
- ),
- exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.TimestampAdd: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.TIMESTAMP
- ),
- exp.TimestampSub: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.TIMESTAMP
- ),
- exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
- exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
- exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.DateStrToDate: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.DATE
- ),
- exp.DateToDateStr: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.VARCHAR
- ),
- exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
- exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
- exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
- exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
- exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
- exp.Concat: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.GroupConcat: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.VARCHAR
- ),
- exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.VARCHAR
- ),
- exp.ArraySize: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
- exp.Map: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
- exp.VarMap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
- exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL),
- exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
- exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
- exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
- exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.DOUBLE
- ),
- exp.RegexpLike: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.BOOLEAN
- ),
- exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.StrToTime: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.TIMESTAMP
- ),
- exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.VARCHAR
- ),
- exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.DATE
- ),
- exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.TIMESTAMP
- ),
- exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.VARCHAR
- ),
- exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.UnixToTime: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.TIMESTAMP
- ),
- exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.VARCHAR
- ),
- exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.VariancePop: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.DOUBLE
- ),
- exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
- exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
- }
+def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
+ return lambda self, e: self._annotate_with_type(e, data_type)
- # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
- COERCES_TO = {
- # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
- exp.DataType.Type.TEXT: set(),
- exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
- exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
- exp.DataType.Type.NCHAR: {
- exp.DataType.Type.VARCHAR,
- exp.DataType.Type.NVARCHAR,
+
+class _TypeAnnotator(type):
+ def __new__(cls, clsname, bases, attrs):
+ klass = super().__new__(cls, clsname, bases, attrs)
+
+ # Highest-to-lowest type precedence, as specified in Spark's docs (ANSI):
+ # https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
+ text_precedence = (
exp.DataType.Type.TEXT,
- },
- exp.DataType.Type.CHAR: {
- exp.DataType.Type.NCHAR,
- exp.DataType.Type.VARCHAR,
exp.DataType.Type.NVARCHAR,
- exp.DataType.Type.TEXT,
- },
- # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
- exp.DataType.Type.DOUBLE: set(),
- exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
- exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
- exp.DataType.Type.BIGINT: {
- exp.DataType.Type.DECIMAL,
- exp.DataType.Type.FLOAT,
+ exp.DataType.Type.VARCHAR,
+ exp.DataType.Type.NCHAR,
+ exp.DataType.Type.CHAR,
+ )
+ numeric_precedence = (
exp.DataType.Type.DOUBLE,
+ exp.DataType.Type.FLOAT,
+ exp.DataType.Type.DECIMAL,
+ exp.DataType.Type.BIGINT,
+ exp.DataType.Type.INT,
+ exp.DataType.Type.SMALLINT,
+ exp.DataType.Type.TINYINT,
+ )
+ timelike_precedence = (
+ exp.DataType.Type.TIMESTAMPLTZ,
+ exp.DataType.Type.TIMESTAMPTZ,
+ exp.DataType.Type.TIMESTAMP,
+ exp.DataType.Type.DATETIME,
+ exp.DataType.Type.DATE,
+ )
+
+ for type_precedence in (text_precedence, numeric_precedence, timelike_precedence):
+ coerces_to = set()
+ for data_type in type_precedence:
+ klass.COERCES_TO[data_type] = coerces_to.copy()
+ coerces_to |= {data_type}
+
+ return klass
+
+
+class TypeAnnotator(metaclass=_TypeAnnotator):
+ TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
+ exp.DataType.Type.BIGINT: {
+ exp.ApproxDistinct,
+ exp.ArraySize,
+ exp.Count,
+ exp.Length,
+ },
+ exp.DataType.Type.BOOLEAN: {
+ exp.Between,
+ exp.Boolean,
+ exp.In,
+ exp.RegexpLike,
+ },
+ exp.DataType.Type.DATE: {
+ exp.CurrentDate,
+ exp.Date,
+ exp.DateAdd,
+ exp.DateStrToDate,
+ exp.DateSub,
+ exp.DateTrunc,
+ exp.DiToDate,
+ exp.StrToDate,
+ exp.TimeStrToDate,
+ exp.TsOrDsToDate,
+ },
+ exp.DataType.Type.DATETIME: {
+ exp.CurrentDatetime,
+ exp.DatetimeAdd,
+ exp.DatetimeSub,
+ },
+ exp.DataType.Type.DOUBLE: {
+ exp.ApproxQuantile,
+ exp.Avg,
+ exp.Exp,
+ exp.Ln,
+ exp.Log,
+ exp.Log2,
+ exp.Log10,
+ exp.Pow,
+ exp.Quantile,
+ exp.Round,
+ exp.SafeDivide,
+ exp.Sqrt,
+ exp.Stddev,
+ exp.StddevPop,
+ exp.StddevSamp,
+ exp.Variance,
+ exp.VariancePop,
},
exp.DataType.Type.INT: {
- exp.DataType.Type.BIGINT,
- exp.DataType.Type.DECIMAL,
- exp.DataType.Type.FLOAT,
- exp.DataType.Type.DOUBLE,
+ exp.Ceil,
+ exp.DateDiff,
+ exp.DatetimeDiff,
+ exp.Extract,
+ exp.TimestampDiff,
+ exp.TimeDiff,
+ exp.DateToDi,
+ exp.Floor,
+ exp.Levenshtein,
+ exp.StrPosition,
+ exp.TsOrDiToDi,
},
- exp.DataType.Type.SMALLINT: {
- exp.DataType.Type.INT,
- exp.DataType.Type.BIGINT,
- exp.DataType.Type.DECIMAL,
- exp.DataType.Type.FLOAT,
- exp.DataType.Type.DOUBLE,
+ exp.DataType.Type.TIMESTAMP: {
+ exp.CurrentTime,
+ exp.CurrentTimestamp,
+ exp.StrToTime,
+ exp.TimeAdd,
+ exp.TimeStrToTime,
+ exp.TimeSub,
+ exp.TimestampAdd,
+ exp.TimestampSub,
+ exp.UnixToTime,
},
exp.DataType.Type.TINYINT: {
- exp.DataType.Type.SMALLINT,
- exp.DataType.Type.INT,
- exp.DataType.Type.BIGINT,
- exp.DataType.Type.DECIMAL,
- exp.DataType.Type.FLOAT,
- exp.DataType.Type.DOUBLE,
+ exp.Day,
+ exp.Month,
+ exp.Week,
+ exp.Year,
},
- # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
- exp.DataType.Type.TIMESTAMPLTZ: set(),
- exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
- exp.DataType.Type.TIMESTAMP: {
- exp.DataType.Type.TIMESTAMPTZ,
- exp.DataType.Type.TIMESTAMPLTZ,
+ exp.DataType.Type.VARCHAR: {
+ exp.ArrayConcat,
+ exp.Concat,
+ exp.ConcatWs,
+ exp.DateToDateStr,
+ exp.GroupConcat,
+ exp.Initcap,
+ exp.Lower,
+ exp.SafeConcat,
+ exp.Substring,
+ exp.TimeToStr,
+ exp.TimeToTimeStr,
+ exp.Trim,
+ exp.TsOrDsToDateStr,
+ exp.UnixToStr,
+ exp.UnixToTimeStr,
+ exp.Upper,
},
- exp.DataType.Type.DATETIME: {
- exp.DataType.Type.TIMESTAMP,
- exp.DataType.Type.TIMESTAMPTZ,
- exp.DataType.Type.TIMESTAMPLTZ,
+ }
+
+ ANNOTATORS = {
+ **{
+ expr_type: lambda self, e: self._annotate_unary(e)
+ for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
},
- exp.DataType.Type.DATE: {
- exp.DataType.Type.DATETIME,
- exp.DataType.Type.TIMESTAMP,
- exp.DataType.Type.TIMESTAMPTZ,
- exp.DataType.Type.TIMESTAMPLTZ,
+ **{
+ expr_type: lambda self, e: self._annotate_binary(e)
+ for expr_type in subclasses(exp.__name__, exp.Binary)
+ },
+ **{
+ expr_type: _annotate_with_type_lambda(data_type)
+ for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
+ for expr_type in expressions
},
+ exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
+ exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
+ exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
+ exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
+ exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
+ exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
+ exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
+ exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
+ exp.Literal: lambda self, e: self._annotate_literal(e),
+ exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
+ exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
+ exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
+ exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
+ exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
+ exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
+ exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
}
- TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)
+ # Specifies what types a given type can be coerced into (autofilled)
+ COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}
- def __init__(self, schema=None, annotators=None, coerces_to=None):
+ def __init__(
+ self,
+ schema: Schema,
+ annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
+ coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
+ ) -> None:
self.schema = schema
self.annotators = annotators or self.ANNOTATORS
self.coerces_to = coerces_to or self.COERCES_TO
- def annotate(self, expression):
- if isinstance(expression, self.TRAVERSABLES):
- for scope in traverse_scope(expression):
- selects = {}
- for name, source in scope.sources.items():
- if not isinstance(source, Scope):
- continue
- if isinstance(source.expression, exp.UDTF):
- values = []
-
- if isinstance(source.expression, exp.Lateral):
- if isinstance(source.expression.this, exp.Explode):
- values = [source.expression.this.this]
- else:
- values = source.expression.expressions[0].expressions
-
- if not values:
- continue
-
- selects[name] = {
- alias: column
- for alias, column in zip(
- source.expression.alias_column_names,
- values,
- )
- }
+ def annotate(self, expression: E) -> E:
+ for scope in traverse_scope(expression):
+ selects = {}
+ for name, source in scope.sources.items():
+ if not isinstance(source, Scope):
+ continue
+ if isinstance(source.expression, exp.UDTF):
+ values = []
+
+ if isinstance(source.expression, exp.Lateral):
+ if isinstance(source.expression.this, exp.Explode):
+ values = [source.expression.this.this]
else:
- selects[name] = {
- select.alias_or_name: select for select in source.expression.selects
- }
- # First annotate the current scope's column references
- for col in scope.columns:
- if not col.table:
+ values = source.expression.expressions[0].expressions
+
+ if not values:
continue
- source = scope.sources.get(col.table)
- if isinstance(source, exp.Table):
- col.type = self.schema.get_column_type(source, col)
- elif source and col.table in selects and col.name in selects[col.table]:
- col.type = selects[col.table][col.name].type
- # Then (possibly) annotate the remaining expressions in the scope
- self._maybe_annotate(scope.expression)
+ selects[name] = {
+ alias: column
+ for alias, column in zip(
+ source.expression.alias_column_names,
+ values,
+ )
+ }
+ else:
+ selects[name] = {
+ select.alias_or_name: select for select in source.expression.selects
+ }
+
+ # First annotate the current scope's column references
+ for col in scope.columns:
+ if not col.table:
+ continue
+
+ source = scope.sources.get(col.table)
+ if isinstance(source, exp.Table):
+ col.type = self.schema.get_column_type(source, col)
+ elif source and col.table in selects and col.name in selects[col.table]:
+ col.type = selects[col.table][col.name].type
+
+ # Then (possibly) annotate the remaining expressions in the scope
+ self._maybe_annotate(scope.expression)
+
return self._maybe_annotate(expression) # This takes care of non-traversable expressions
- def _maybe_annotate(self, expression):
+ def _maybe_annotate(self, expression: E) -> E:
if expression.type:
return expression # We've already inferred the expression's type
@@ -312,13 +290,15 @@ class TypeAnnotator:
else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
)
- def _annotate_args(self, expression):
+ def _annotate_args(self, expression: E) -> E:
for _, value in expression.iter_expressions():
self._maybe_annotate(value)
return expression
- def _maybe_coerce(self, type1, type2):
+ def _maybe_coerce(
+ self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
+ ) -> exp.DataType.Type:
# We propagate the NULL / UNKNOWN types upwards if found
if isinstance(type1, exp.DataType):
type1 = type1.this
@@ -330,9 +310,14 @@ class TypeAnnotator:
if exp.DataType.Type.UNKNOWN in (type1, type2):
return exp.DataType.Type.UNKNOWN
- return type2 if type2 in self.coerces_to.get(type1, {}) else type1
+ return type2 if type2 in self.coerces_to.get(type1, {}) else type1 # type: ignore
- def _annotate_binary(self, expression):
+ # Note: the following "no_type_check" decorators were added because mypy was yelling due
+ # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
+ # This is a known mypy issue: https://github.com/python/mypy/issues/3004
+
+ @t.no_type_check
+ def _annotate_binary(self, expression: B) -> B:
self._annotate_args(expression)
left_type = expression.left.type.this
@@ -354,7 +339,8 @@ class TypeAnnotator:
return expression
- def _annotate_unary(self, expression):
+ @t.no_type_check
+ def _annotate_unary(self, expression: E) -> E:
self._annotate_args(expression)
if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
@@ -364,7 +350,8 @@ class TypeAnnotator:
return expression
- def _annotate_literal(self, expression):
+ @t.no_type_check
+ def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
if expression.is_string:
expression.type = exp.DataType.Type.VARCHAR
elif expression.is_int:
@@ -374,13 +361,16 @@ class TypeAnnotator:
return expression
- def _annotate_with_type(self, expression, target_type):
+ @t.no_type_check
+ def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
expression.type = target_type
return self._annotate_args(expression)
- def _annotate_by_args(self, expression, *args, promote=False):
+ @t.no_type_check
+ def _annotate_by_args(self, expression: E, *args: str, promote: bool = False) -> E:
self._annotate_args(expression)
- expressions = []
+
+ expressions: t.List[exp.Expression] = []
for arg in args:
arg_expr = expression.args.get(arg)
expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py
index da2fce8..015b06a 100644
--- a/sqlglot/optimizer/canonicalize.py
+++ b/sqlglot/optimizer/canonicalize.py
@@ -26,7 +26,7 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
def add_text_to_concat(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES:
- node = exp.Concat(this=node.this, expression=node.expression)
+ node = exp.Concat(expressions=[node.left, node.right])
return node
diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py
index 27de9c7..cd8ba3b 100644
--- a/sqlglot/optimizer/eliminate_joins.py
+++ b/sqlglot/optimizer/eliminate_joins.py
@@ -32,7 +32,7 @@ def eliminate_joins(expression):
# Reverse the joins so we can remove chains of unused joins
for join in reversed(joins):
- alias = join.this.alias_or_name
+ alias = join.alias_or_name
if _should_eliminate_join(scope, join, alias):
join.pop()
scope.remove_source(alias)
@@ -126,7 +126,7 @@ def join_condition(join):
tuple[list[str], list[str], exp.Expression]:
Tuple of (source key, join key, remaining predicate)
"""
- name = join.this.alias_or_name
+ name = join.alias_or_name
on = (join.args.get("on") or exp.true()).copy()
source_key = []
join_key = []
diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py
index 5dfa4aa..79e3ed5 100644
--- a/sqlglot/optimizer/isolate_table_selects.py
+++ b/sqlglot/optimizer/isolate_table_selects.py
@@ -21,7 +21,7 @@ def isolate_table_selects(expression, schema=None):
source.replace(
exp.select("*")
.from_(
- alias(source, source.name or source.alias, table=True),
+ alias(source, source.alias_or_name, table=True),
copy=False,
)
.subquery(source.alias, copy=False)
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
index f9c9664..fefe96e 100644
--- a/sqlglot/optimizer/merge_subqueries.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -145,7 +145,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
if not isinstance(from_or_join, exp.Join):
return False
- alias = from_or_join.this.alias_or_name
+ alias = from_or_join.alias_or_name
on = from_or_join.args.get("on")
if not on:
@@ -253,10 +253,6 @@ def _merge_joins(outer_scope, inner_scope, from_or_join):
"""
new_joins = []
- comma_joins = inner_scope.expression.args.get("from").expressions[1:]
- for subquery in comma_joins:
- new_joins.append(exp.Join(this=subquery, kind="CROSS"))
- outer_scope.add_source(subquery.alias_or_name, inner_scope.sources[subquery.alias_or_name])
joins = inner_scope.expression.args.get("joins") or []
for join in joins:
@@ -328,13 +324,12 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
if source == from_or_join.alias_or_name:
break
- if set(exp.column_table_names(where.this)) <= sources:
+ if exp.column_table_names(where.this) <= sources:
from_or_join.on(where.this, copy=False)
from_or_join.set("on", from_or_join.args.get("on"))
return
expression.where(where.this, copy=False)
- expression.set("where", expression.args.get("where"))
def _merge_order(outer_scope, inner_scope):
diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py
index 4e0c3a1..d51276f 100644
--- a/sqlglot/optimizer/optimize_joins.py
+++ b/sqlglot/optimizer/optimize_joins.py
@@ -1,3 +1,7 @@
+from __future__ import annotations
+
+import typing as t
+
from sqlglot import exp
from sqlglot.helper import tsort
@@ -13,25 +17,28 @@ def optimize_joins(expression):
>>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql()
'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a'
"""
+
for select in expression.find_all(exp.Select):
references = {}
cross_joins = []
for join in select.args.get("joins", []):
- name = join.this.alias_or_name
- tables = other_table_names(join, name)
+ tables = other_table_names(join)
if tables:
for table in tables:
references[table] = references.get(table, []) + [join]
else:
- cross_joins.append((name, join))
+ cross_joins.append((join.alias_or_name, join))
for name, join in cross_joins:
for dep in references.get(name, []):
on = dep.args["on"]
if isinstance(on, exp.Connector):
+ if len(other_table_names(dep)) < 2:
+ continue
+
for predicate in on.flatten():
if name in exp.column_table_names(predicate):
predicate.replace(exp.true())
@@ -47,17 +54,12 @@ def reorder_joins(expression):
Reorder joins by topological sort order based on predicate references.
"""
for from_ in expression.find_all(exp.From):
- head = from_.this
parent = from_.parent
- joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])}
- dag = {head.alias_or_name: []}
-
- for name, join in joins.items():
- dag[name] = other_table_names(join, name)
-
+ joins = {join.alias_or_name: join for join in parent.args.get("joins", [])}
+ dag = {name: other_table_names(join) for name, join in joins.items()}
parent.set(
"joins",
- [joins[name] for name in tsort(dag) if name != head.alias_or_name],
+ [joins[name] for name in tsort(dag) if name != from_.alias_or_name],
)
return expression
@@ -75,9 +77,6 @@ def normalize(expression):
return expression
-def other_table_names(join, exclude):
- return [
- name
- for name in (exp.column_table_names(join.args.get("on") or exp.true()))
- if name != exclude
- ]
+def other_table_names(join: exp.Join) -> t.Set[str]:
+ on = join.args.get("on")
+ return exp.column_table_names(on, join.alias_or_name) if on else set()
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index dbe33a2..abac63b 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -78,7 +78,7 @@ def optimize(
"schema": schema,
"dialect": dialect,
"isolate_tables": True, # needed for other optimizations to perform well
- "quote_identifiers": False, # this happens in canonicalize
+ "quote_identifiers": False,
**kwargs,
}
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py
index b89a82b..fb1662d 100644
--- a/sqlglot/optimizer/pushdown_predicates.py
+++ b/sqlglot/optimizer/pushdown_predicates.py
@@ -41,7 +41,7 @@ def pushdown_predicates(expression):
# joins should only pushdown into itself, not to other joins
# so we limit the selected sources to only itself
for join in select.args.get("joins") or []:
- name = join.this.alias_or_name
+ name = join.alias_or_name
pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count)
return expression
@@ -93,10 +93,10 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
pushdown_tables = set()
for a in predicates:
- a_tables = set(exp.column_table_names(a))
+ a_tables = exp.column_table_names(a)
for b in predicates:
- a_tables &= set(exp.column_table_names(b))
+ a_tables &= exp.column_table_names(b)
pushdown_tables.update(a_tables)
@@ -147,7 +147,7 @@ def nodes_for_predicate(predicate, sources, scope_ref_count):
tables = exp.column_table_names(predicate)
where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where)
- for table in tables:
+ for table in sorted(tables):
node, source = sources.get(table) or (None, None)
# if the predicate is in a where statement we can try to push it down
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 4a31171..aba9a7e 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -14,7 +14,7 @@ from sqlglot.schema import Schema, ensure_schema
def qualify_columns(
expression: exp.Expression,
- schema: dict | Schema,
+ schema: t.Dict | Schema,
expand_alias_refs: bool = True,
infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
@@ -93,7 +93,7 @@ def _pop_table_column_aliases(derived_tables):
def _expand_using(scope, resolver):
joins = list(scope.find_all(exp.Join))
- names = {join.this.alias for join in joins}
+ names = {join.alias_or_name for join in joins}
ordered = [key for key in scope.selected_sources if key not in names]
# Mapping of automatically joined column names to an ordered set of source names (dict).
@@ -105,7 +105,7 @@ def _expand_using(scope, resolver):
if not using:
continue
- join_table = join.this.alias_or_name
+ join_table = join.alias_or_name
columns = {}
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index fcc5f26..9c931d6 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -91,11 +91,13 @@ def qualify_tables(
)
elif isinstance(source, Scope) and source.is_udtf:
udtf = source.expression
- table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_alias_name())
+ table_alias = udtf.args.get("alias") or exp.TableAlias(
+ this=exp.to_identifier(next_alias_name())
+ )
udtf.set("alias", table_alias)
if not table_alias.name:
- table_alias.set("this", next_alias_name())
+ table_alias.set("this", exp.to_identifier(next_alias_name()))
if isinstance(udtf, exp.Values) and not table_alias.columns:
for i, e in enumerate(udtf.expressions[0].expressions):
table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index 9ffb4d6..aa56b83 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -620,7 +620,7 @@ def _traverse_tables(scope):
table_name = expression.name
source_name = expression.alias_or_name
- if table_name in scope.sources:
+ if table_name in scope.sources and not expression.db:
# This is a reference to a parent source (e.g. a CTE), not an actual table, unless
# it is pivoted, because then we get back a new table and hence a new source.
pivots = expression.args.get("pivots")