From 8fe30fd23dc37ec3516e530a86d1c4b604e71241 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sun, 10 Dec 2023 11:46:01 +0100 Subject: Merging upstream version 20.1.0. Signed-off-by: Daniel Baumann --- sqlglot/optimizer/annotate_types.py | 110 ++++++++++++++++------------- sqlglot/optimizer/canonicalize.py | 85 ++++++++++++++++++---- sqlglot/optimizer/merge_subqueries.py | 4 +- sqlglot/optimizer/normalize_identifiers.py | 6 +- sqlglot/optimizer/optimizer.py | 4 +- sqlglot/optimizer/qualify_columns.py | 47 ++++++++---- sqlglot/optimizer/qualify_tables.py | 15 ++-- sqlglot/optimizer/scope.py | 2 + sqlglot/optimizer/simplify.py | 73 ++++++++++++------- 9 files changed, 235 insertions(+), 111 deletions(-) (limited to 'sqlglot/optimizer') diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 69d4567..7b990f1 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -1,12 +1,18 @@ from __future__ import annotations -import datetime import functools import typing as t from sqlglot import exp from sqlglot._typing import E -from sqlglot.helper import ensure_list, seq_get, subclasses +from sqlglot.helper import ( + ensure_list, + is_date_unit, + is_iso_date, + is_iso_datetime, + seq_get, + subclasses, +) from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.schema import Schema, ensure_schema @@ -20,10 +26,6 @@ if t.TYPE_CHECKING: ] -# Interval units that operate on date components -DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"} - - def annotate_types( expression: E, schema: t.Optional[t.Dict | Schema] = None, @@ -60,43 +62,22 @@ def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[Type return lambda self, e: self._annotate_with_type(e, data_type) -def _is_iso_date(text: str) -> bool: - try: - datetime.date.fromisoformat(text) - return True - except ValueError: - return False - - -def _is_iso_datetime(text: str) -> bool: - try: - datetime.datetime.fromisoformat(text) - return True - except ValueError: - return False - - -def _coerce_literal_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type: +def _coerce_date_literal(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type: date_text = l.name - unit = r.text("unit").lower() - - is_iso_date = _is_iso_date(date_text) + is_iso_date_ = is_iso_date(date_text) - if is_iso_date and unit in DATE_UNITS: - l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATE)) + if is_iso_date_ and is_date_unit(unit): return exp.DataType.Type.DATE # An ISO date is also an ISO datetime, but not vice versa - if is_iso_date or _is_iso_datetime(date_text): - l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATETIME)) + if is_iso_date_ or is_iso_datetime(date_text): return exp.DataType.Type.DATETIME return exp.DataType.Type.UNKNOWN -def _coerce_date_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type: - unit = r.text("unit").lower() - if unit not in DATE_UNITS: +def _coerce_date(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type: + if not is_date_unit(unit): return exp.DataType.Type.DATETIME return l.type.this if l.type else exp.DataType.Type.UNKNOWN @@ -171,7 +152,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.Date, exp.DateFromParts, exp.DateStrToDate, - exp.DateTrunc, exp.DiToDate, exp.StrToDate, exp.TimeStrToDate, @@ -185,6 +165,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.DataType.Type.DOUBLE: { exp.ApproxQuantile, exp.Avg, + 
exp.Div, exp.Exp, exp.Ln, exp.Log, @@ -203,8 +184,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator): }, exp.DataType.Type.INT: { exp.Ceil, - exp.DateDiff, exp.DatetimeDiff, + exp.DateDiff, exp.Extract, exp.TimestampDiff, exp.TimeDiff, @@ -240,8 +221,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.GroupConcat, exp.Initcap, exp.Lower, - exp.SafeConcat, - exp.SafeDPipe, exp.Substring, exp.TimeToStr, exp.TimeToTimeStr, @@ -267,6 +246,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator): for data_type, expressions in TYPE_TO_EXPRESSIONS.items() for expr_type in expressions }, + exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), @@ -276,9 +256,11 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), - exp.DateAdd: lambda self, e: self._annotate_dateadd(e), - exp.DateSub: lambda self, e: self._annotate_dateadd(e), + exp.DateAdd: lambda self, e: self._annotate_timeunit(e), + exp.DateSub: lambda self, e: self._annotate_timeunit(e), + exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), + exp.Div: lambda self, e: self._annotate_div(e), exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), @@ -288,6 +270,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), + exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), @@ -306,13 +289,27 @@ class TypeAnnotator(metaclass=_TypeAnnotator): BINARY_COERCIONS: BinaryCoercions = { **swap_all( { - (t, exp.DataType.Type.INTERVAL): _coerce_literal_and_interval + (t, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date_literal( + l, r.args.get("unit") + ) for t in exp.DataType.TEXT_TYPES } ), **swap_all( { - (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): _coerce_date_and_interval, + # text + numeric will yield the numeric type to match most dialects' semantics + (text, numeric): lambda l, r: t.cast( + exp.DataType.Type, l.type if l.type in exp.DataType.NUMERIC_TYPES else r.type + ) + for text in exp.DataType.TEXT_TYPES + for numeric in exp.DataType.NUMERIC_TYPES + } + ), + **swap_all( + { + (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date( + l, r.args.get("unit") + ), } ), } @@ -511,18 +508,17 @@ class TypeAnnotator(metaclass=_TypeAnnotator): return expression - def _annotate_dateadd(self, expression: exp.IntervalOp) -> exp.IntervalOp: + def _annotate_timeunit( + self, expression: exp.TimeUnit | exp.DateTrunc + ) -> exp.TimeUnit | 
exp.DateTrunc: self._annotate_args(expression) if expression.this.type.this in exp.DataType.TEXT_TYPES: - datatype = _coerce_literal_and_interval(expression.this, expression.interval()) - elif ( - expression.this.type.is_type(exp.DataType.Type.DATE) - and expression.text("unit").lower() not in DATE_UNITS - ): - datatype = exp.DataType.Type.DATETIME + datatype = _coerce_date_literal(expression.this, expression.unit) + elif expression.this.type.this in exp.DataType.TEMPORAL_TYPES: + datatype = _coerce_date(expression.this, expression.unit) else: - datatype = expression.this.type + datatype = exp.DataType.Type.UNKNOWN self._set_type(expression, datatype) return expression @@ -547,3 +543,19 @@ class TypeAnnotator(metaclass=_TypeAnnotator): self._set_type(expression, exp.DataType.Type.UNKNOWN) return expression + + def _annotate_div(self, expression: exp.Div) -> exp.Div: + self._annotate_args(expression) + + left_type, right_type = expression.left.type.this, expression.right.type.this # type: ignore + + if ( + expression.args.get("typed") + and left_type in exp.DataType.INTEGER_TYPES + and right_type in exp.DataType.INTEGER_TYPES + ): + self._set_type(expression, exp.DataType.Type.BIGINT) + else: + self._set_type(expression, self._maybe_coerce(left_type, right_type)) + + return expression diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py index fc5c348..faf18c6 100644 --- a/sqlglot/optimizer/canonicalize.py +++ b/sqlglot/optimizer/canonicalize.py @@ -1,8 +1,10 @@ from __future__ import annotations import itertools +import typing as t from sqlglot import exp +from sqlglot.helper import is_date_unit, is_iso_date, is_iso_datetime def canonicalize(expression: exp.Expression) -> exp.Expression: @@ -20,7 +22,7 @@ def canonicalize(expression: exp.Expression) -> exp.Expression: expression = replace_date_funcs(expression) expression = coerce_type(expression) expression = remove_redundant_casts(expression) - expression = ensure_bool_predicates(expression) + expression = ensure_bools(expression, _replace_int_predicate) expression = remove_ascending_order(expression) return expression @@ -40,8 +42,22 @@ def replace_date_funcs(node: exp.Expression) -> exp.Expression: return node +COERCIBLE_DATE_OPS = ( + exp.Add, + exp.Sub, + exp.EQ, + exp.NEQ, + exp.GT, + exp.GTE, + exp.LT, + exp.LTE, + exp.NullSafeEQ, + exp.NullSafeNEQ, +) + + def coerce_type(node: exp.Expression) -> exp.Expression: - if isinstance(node, exp.Binary): + if isinstance(node, COERCIBLE_DATE_OPS): _coerce_date(node.left, node.right) elif isinstance(node, exp.Between): _coerce_date(node.this, node.args["low"]) @@ -49,6 +65,10 @@ def coerce_type(node: exp.Expression) -> exp.Expression: *exp.DataType.TEMPORAL_TYPES ): _replace_cast(node.expression, exp.DataType.Type.DATETIME) + elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)): + _coerce_timeunit_arg(node.this, node.unit) + elif isinstance(node, exp.DateDiff): + _coerce_datediff_args(node) return node @@ -64,17 +84,21 @@ def remove_redundant_casts(expression: exp.Expression) -> exp.Expression: return expression -def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression: +def ensure_bools( + expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None] +) -> exp.Expression: if isinstance(expression, exp.Connector): - _replace_int_predicate(expression.left) - _replace_int_predicate(expression.right) - - elif isinstance(expression, (exp.Where, exp.Having)) or ( + replace_func(expression.left) + 
replace_func(expression.right) + elif isinstance(expression, exp.Not): + replace_func(expression.this) # We can't replace num in CASE x WHEN num ..., because it's not the full predicate - isinstance(expression, exp.If) - and not (isinstance(expression.parent, exp.Case) and expression.parent.this) + elif isinstance(expression, exp.If) and not ( + isinstance(expression.parent, exp.Case) and expression.parent.this ): - _replace_int_predicate(expression.this) + replace_func(expression.this) + elif isinstance(expression, (exp.Where, exp.Having)): + replace_func(expression.this) return expression @@ -89,22 +113,59 @@ def remove_ascending_order(expression: exp.Expression) -> exp.Expression: def _coerce_date(a: exp.Expression, b: exp.Expression) -> None: for a, b in itertools.permutations([a, b]): + if isinstance(b, exp.Interval): + a = _coerce_timeunit_arg(a, b.unit) if ( a.type and a.type.this == exp.DataType.Type.DATE and b.type - and b.type.this not in (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL) + and b.type.this + not in ( + exp.DataType.Type.DATE, + exp.DataType.Type.INTERVAL, + ) ): _replace_cast(b, exp.DataType.Type.DATE) +def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression: + if not arg.type: + return arg + + if arg.type.this in exp.DataType.TEXT_TYPES: + date_text = arg.name + is_iso_date_ = is_iso_date(date_text) + + if is_iso_date_ and is_date_unit(unit): + return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATE)) + + # An ISO date is also an ISO datetime, but not vice versa + if is_iso_date_ or is_iso_datetime(date_text): + return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME)) + + elif arg.type.this == exp.DataType.Type.DATE and not is_date_unit(unit): + return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME)) + + return arg + + +def _coerce_datediff_args(node: exp.DateDiff) -> None: + for e in (node.this, node.expression): + if e.type.this not in exp.DataType.TEMPORAL_TYPES: + e.replace(exp.cast(e.copy(), to=exp.DataType.Type.DATETIME)) + + def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None: node.replace(exp.cast(node.copy(), to=to)) +# this was originally designed for presto, there is a similar transform for tsql +# this is different in that it only operates on int types, this is because +# presto has a boolean type whereas tsql doesn't (people use bits) +# with y as (select true as x) select x = 0 FROM y -- illegal presto query def _replace_int_predicate(expression: exp.Expression) -> None: if isinstance(expression, exp.Coalesce): for _, child in expression.iter_expressions(): _replace_int_predicate(child) elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES: - expression.replace(exp.NEQ(this=expression.copy(), expression=exp.Literal.number(0))) + expression.replace(expression.neq(0)) diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index b0b2b3d..a74bea7 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -186,13 +186,13 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): and not ( isinstance(from_or_join, exp.Join) and inner_select.args.get("where") - and from_or_join.side in {"FULL", "LEFT", "RIGHT"} + and from_or_join.side in ("FULL", "LEFT", "RIGHT") ) and not ( isinstance(from_or_join, exp.From) and inner_select.args.get("where") and any( - j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", []) 
+ j.side in ("FULL", "RIGHT") for j in outer_scope.expression.args.get("joins", []) ) ) and not _outer_select_joins_on_inner_select_join() diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py index 154256e..3361a33 100644 --- a/sqlglot/optimizer/normalize_identifiers.py +++ b/sqlglot/optimizer/normalize_identifiers.py @@ -13,7 +13,7 @@ def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: @t.overload -def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Expression: +def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: ... @@ -48,11 +48,11 @@ def normalize_identifiers(expression, dialect=None): Returns: The transformed expression. """ + dialect = Dialect.get_or_raise(dialect) + if isinstance(expression, str): expression = exp.parse_identifier(expression, dialect=dialect) - dialect = Dialect.get_or_raise(dialect) - def _normalize(node: E) -> E: if not node.meta.get("case_sensitive"): exp.replace_children(node, _normalize) diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index abac63b..1c96e95 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -42,8 +42,8 @@ RULES = ( def optimize( expression: str | exp.Expression, schema: t.Optional[dict | Schema] = None, - db: t.Optional[str] = None, - catalog: t.Optional[str] = None, + db: t.Optional[str | exp.Identifier] = None, + catalog: t.Optional[str | exp.Identifier] = None, dialect: DialectType = None, rules: t.Sequence[t.Callable] = RULES, **kwargs, diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index b06ea1d..742cdf5 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -8,7 +8,7 @@ from sqlglot._typing import E from sqlglot.dialects.dialect import Dialect, DialectType from sqlglot.errors import OptimizeError from sqlglot.helper import seq_get -from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope +from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope from sqlglot.optimizer.simplify import simplify_parens from sqlglot.schema import Schema, ensure_schema @@ -58,7 +58,7 @@ def qualify_columns( if not isinstance(scope.expression, exp.UDTF): _expand_stars(scope, resolver, using_column_tables, pseudocolumns) - _qualify_outputs(scope) + qualify_outputs(scope) _expand_group_by(scope) _expand_order_by(scope, resolver) @@ -237,7 +237,7 @@ def _expand_order_by(scope: Scope, resolver: Resolver) -> None: ordereds = order.expressions for ordered, new_expression in zip( ordereds, - _expand_positional_references(scope, (o.this for o in ordereds)), + _expand_positional_references(scope, (o.this for o in ordereds), alias=True), ): for agg in ordered.find_all(exp.AggFunc): for col in agg.find_all(exp.Column): @@ -259,17 +259,23 @@ def _expand_order_by(scope: Scope, resolver: Resolver) -> None: ) -def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]: - new_nodes = [] +def _expand_positional_references( + scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False +) -> t.List[exp.Expression]: + new_nodes: t.List[exp.Expression] = [] for node in expressions: if node.is_int: - select = _select_by_pos(scope, t.cast(exp.Literal, node)).this + select = _select_by_pos(scope, t.cast(exp.Literal, node)) - if isinstance(select, exp.Literal): - new_nodes.append(node) + if alias: + 
new_nodes.append(exp.column(select.args["alias"].copy())) else: - new_nodes.append(select.copy()) - scope.clear_cache() + select = select.this + + if isinstance(select, exp.Literal): + new_nodes.append(node) + else: + new_nodes.append(select.copy()) else: new_nodes.append(node) @@ -307,7 +313,9 @@ def _qualify_columns(scope: Scope, resolver: Resolver) -> None: if column_table: column.set("table", column_table) elif column_table not in scope.sources and ( - not scope.parent or column_table not in scope.parent.sources + not scope.parent + or column_table not in scope.parent.sources + or not scope.is_correlated_subquery ): # structs are used like tables (e.g. "struct"."field"), so they need to be qualified # separately and represented as dot(dot(...(., field1), field2, ...)) @@ -381,15 +389,18 @@ def _expand_stars( columns = [name for name in columns if name.upper() not in pseudocolumns] if columns and "*" not in columns: + table_id = id(table) + columns_to_exclude = except_columns.get(table_id) or set() + if pivot and has_pivoted_source and pivot_columns and pivot_output_columns: implicit_columns = [col for col in columns if col not in pivot_columns] new_selections.extend( exp.alias_(exp.column(name, table=pivot.alias), name, copy=False) for name in implicit_columns + pivot_output_columns + if name not in columns_to_exclude ) continue - table_id = id(table) for name in columns: if name in using_column_tables and table in using_column_tables[name]: if name in coalesced_columns: @@ -406,7 +417,7 @@ def _expand_stars( copy=False, ) ) - elif name not in except_columns.get(table_id, set()): + elif name not in columns_to_exclude: alias_ = replace_columns.get(table_id, {}).get(name, name) column = exp.column(name, table=table) new_selections.append( @@ -448,10 +459,16 @@ def _add_replace_columns( replace_columns[id(table)] = columns -def _qualify_outputs(scope: Scope) -> None: +def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None: """Ensure all output columns are aliased""" - new_selections = [] + if isinstance(scope_or_expression, exp.Expression): + scope = build_scope(scope_or_expression) + if not isinstance(scope, Scope): + return + else: + scope = scope_or_expression + new_selections = [] for i, (selection, aliased_column) in enumerate( itertools.zip_longest(scope.expression.selects, scope.outer_column_list) ): diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index 3a43e8f..57ecabe 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import itertools import typing as t from sqlglot import alias, exp from sqlglot._typing import E +from sqlglot.dialects.dialect import DialectType from sqlglot.helper import csv_reader, name_sequence from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.schema import Schema @@ -10,9 +13,10 @@ from sqlglot.schema import Schema def qualify_tables( expression: E, - db: t.Optional[str] = None, - catalog: t.Optional[str] = None, + db: t.Optional[str | exp.Identifier] = None, + catalog: t.Optional[str | exp.Identifier] = None, schema: t.Optional[Schema] = None, + dialect: DialectType = None, ) -> E: """ Rewrite sqlglot AST to have fully qualified tables. Join constructs such as @@ -33,11 +37,14 @@ def qualify_tables( db: Database name catalog: Catalog name schema: A schema to populate + dialect: The dialect to parse catalog and schema into. Returns: The qualified expression. 
""" next_alias_name = name_sequence("_q_") + db = exp.parse_identifier(db, dialect=dialect) if db else None + catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None for scope in traverse_scope(expression): for derived_table in itertools.chain(scope.ctes, scope.derived_tables): @@ -61,9 +68,9 @@ def qualify_tables( if isinstance(source, exp.Table): if isinstance(source.this, exp.Identifier): if not source.args.get("db"): - source.set("db", exp.to_identifier(db)) + source.set("db", db) if not source.args.get("catalog") and source.args.get("db"): - source.set("catalog", exp.to_identifier(catalog)) + source.set("catalog", catalog) if not source.alias: # Mutates the source by attaching an alias to it diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 4af5b49..b7e527e 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import itertools import logging import typing as t diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index af03332..d4e2e60 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -507,6 +507,9 @@ def simplify_literals(expression, root=True): return exp.Literal.number(value[1:]) return exp.Literal.number(f"-{value}") + if type(expression) in INVERSE_DATE_OPS: + return _simplify_binary(expression, expression.this, expression.interval()) or expression + return expression @@ -530,22 +533,24 @@ def _simplify_binary(expression, a, b): return exp.null() if a.is_number and b.is_number: - a = int(a.name) if a.is_int else Decimal(a.name) - b = int(b.name) if b.is_int else Decimal(b.name) + num_a = int(a.name) if a.is_int else Decimal(a.name) + num_b = int(b.name) if b.is_int else Decimal(b.name) if isinstance(expression, exp.Add): - return exp.Literal.number(a + b) - if isinstance(expression, exp.Sub): - return exp.Literal.number(a - b) + return exp.Literal.number(num_a + num_b) if isinstance(expression, exp.Mul): - return exp.Literal.number(a * b) + return exp.Literal.number(num_a * num_b) + + # We only simplify Sub, Div if a and b have the same parent because they're not associative + if isinstance(expression, exp.Sub): + return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None if isinstance(expression, exp.Div): # engines have differing int div behavior so intdiv is not safe - if isinstance(a, int) and isinstance(b, int): + if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent: return None - return exp.Literal.number(a / b) + return exp.Literal.number(num_a / num_b) - boolean = eval_boolean(expression, a, b) + boolean = eval_boolean(expression, num_a, num_b) if boolean: return boolean @@ -557,15 +562,21 @@ def _simplify_binary(expression, a, b): elif _is_date_literal(a) and isinstance(b, exp.Interval): a, b = extract_date(a), extract_interval(b) if a and b: - if isinstance(expression, exp.Add): + if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)): return date_literal(a + b) - if isinstance(expression, exp.Sub): + if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)): return date_literal(a - b) elif isinstance(a, exp.Interval) and _is_date_literal(b): a, b = extract_interval(a), extract_date(b) # you cannot subtract a date from an interval if a and b and isinstance(expression, exp.Add): return date_literal(a + b) + elif _is_date_literal(a) and _is_date_literal(b): + if isinstance(expression, exp.Predicate): + a, b = extract_date(a), 
extract_date(b) + boolean = eval_boolean(expression, a, b) + if boolean: + return boolean return None @@ -590,6 +601,11 @@ def simplify_parens(expression): return expression +NONNULL_CONSTANTS = ( + exp.Literal, + exp.Boolean, +) + CONSTANTS = ( exp.Literal, exp.Boolean, @@ -597,11 +613,19 @@ CONSTANTS = ( ) +def _is_nonnull_constant(expression: exp.Expression) -> bool: + return isinstance(expression, NONNULL_CONSTANTS) or _is_date_literal(expression) + + +def _is_constant(expression: exp.Expression) -> bool: + return isinstance(expression, CONSTANTS) or _is_date_literal(expression) + + def simplify_coalesce(expression): # COALESCE(x) -> x if ( isinstance(expression, exp.Coalesce) - and not expression.expressions + and (not expression.expressions or _is_nonnull_constant(expression.this)) # COALESCE is also used as a Spark partitioning hint and not isinstance(expression.parent, exp.Hint) ): @@ -621,12 +645,12 @@ def simplify_coalesce(expression): # This transformation is valid for non-constants, # but it really only does anything if they are both constants. - if not isinstance(other, CONSTANTS): + if not _is_constant(other): return expression # Find the first constant arg for arg_index, arg in enumerate(coalesce.expressions): - if isinstance(arg, CONSTANTS): + if _is_constant(other): break else: return expression @@ -656,7 +680,6 @@ def simplify_coalesce(expression): CONCATS = (exp.Concat, exp.DPipe) -SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe) def simplify_concat(expression): @@ -672,10 +695,15 @@ def simplify_concat(expression): sep_expr, *expressions = expression.expressions sep = sep_expr.name concat_type = exp.ConcatWs + args = {} else: expressions = expression.expressions sep = "" - concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat + concat_type = exp.Concat + args = { + "safe": expression.args.get("safe"), + "coalesce": expression.args.get("coalesce"), + } new_args = [] for is_string_group, group in itertools.groupby( @@ -692,7 +720,7 @@ def simplify_concat(expression): if concat_type is exp.ConcatWs: new_args = [sep_expr] + new_args - return concat_type(expressions=new_args) + return concat_type(expressions=new_args, **args) def simplify_conditionals(expression): @@ -947,7 +975,7 @@ def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.da def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]: if isinstance(cast, exp.Cast): to = cast.to - elif isinstance(cast, exp.TsOrDsToDate): + elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"): to = exp.DataType.build(exp.DataType.Type.DATE) else: return None @@ -966,12 +994,11 @@ def _is_date_literal(expression: exp.Expression) -> bool: def extract_interval(expression): - n = int(expression.name) - unit = expression.text("unit").lower() - try: + n = int(expression.name) + unit = expression.text("unit").lower() return interval(unit, n) - except (UnsupportedUnit, ModuleNotFoundError): + except (UnsupportedUnit, ModuleNotFoundError, ValueError): return None @@ -1099,8 +1126,6 @@ GEN_MAP = { exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}", exp.Div: lambda e: _binary(e, "/"), exp.Dot: lambda e: _binary(e, "."), - exp.DPipe: lambda e: _binary(e, "||"), - exp.SafeDPipe: lambda e: _binary(e, "||"), exp.EQ: lambda e: _binary(e, "="), exp.GT: lambda e: _binary(e, ">"), exp.GTE: lambda e: _binary(e, ">="), -- cgit v1.2.3
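
A minimal sketch (not part of the upstream patch) of how the reworked annotators behave after this version bump, assuming sqlglot 20.1.0 is installed. The query strings, literals and variable names below are made up for illustration; the expected result types follow from the _coerce_date_literal and _annotate_div code shown above.

    import sqlglot
    from sqlglot import exp
    from sqlglot.optimizer.annotate_types import annotate_types

    # Text + interval: an ISO date literal combined with a date-grained
    # interval unit is annotated as DATE via _coerce_date_literal.
    dated = annotate_types(sqlglot.parse_one("SELECT '2023-01-01' + INTERVAL '1' DAY"))
    print(dated.selects[0].type.sql())  # DATE

    # Division: with the "typed" flag set and two integer operands,
    # _annotate_div yields BIGINT; otherwise it coerces the operand types.
    typed_div = exp.Div(
        this=exp.Literal.number(10),
        expression=exp.Literal.number(3),
        typed=True,
    )
    annotated = annotate_types(exp.select(typed_div))
    print(annotated.selects[0].type.sql())  # BIGINT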
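
Likewise, a small usage sketch of the new dialect-aware db/catalog handling in qualify_tables; the names "db1", "cat1" and the duckdb dialect are arbitrary choices, not taken from the patch.

    import sqlglot
    from sqlglot.optimizer.qualify_tables import qualify_tables

    expression = sqlglot.parse_one("SELECT a FROM tbl")
    qualified = qualify_tables(expression, db="db1", catalog="cat1", dialect="duckdb")

    # Unqualified tables pick up the db and catalog, which are now parsed as
    # identifiers in the given dialect, e.g. SELECT a FROM cat1.db1.tbl AS tbl
    print(qualified.sql(dialect="duckdb"))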