summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer
diff options
context:
space:
mode:
authorDaniel Baumann <mail@daniel-baumann.ch>2023-12-10 10:46:01 +0000
committerDaniel Baumann <mail@daniel-baumann.ch>2023-12-10 10:46:01 +0000
commit8fe30fd23dc37ec3516e530a86d1c4b604e71241 (patch)
tree6e2ebbf565b0351fd0f003f488a8339e771ad90c /sqlglot/optimizer
parentReleasing debian version 19.0.1-1. (diff)
downloadsqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.tar.xz
sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.zip
Merging upstream version 20.1.0.
Signed-off-by: Daniel Baumann <mail@daniel-baumann.ch>
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r--sqlglot/optimizer/annotate_types.py110
-rw-r--r--sqlglot/optimizer/canonicalize.py85
-rw-r--r--sqlglot/optimizer/merge_subqueries.py4
-rw-r--r--sqlglot/optimizer/normalize_identifiers.py6
-rw-r--r--sqlglot/optimizer/optimizer.py4
-rw-r--r--sqlglot/optimizer/qualify_columns.py47
-rw-r--r--sqlglot/optimizer/qualify_tables.py15
-rw-r--r--sqlglot/optimizer/scope.py2
-rw-r--r--sqlglot/optimizer/simplify.py73
9 files changed, 235 insertions, 111 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index 69d4567..7b990f1 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -1,12 +1,18 @@
from __future__ import annotations
-import datetime
import functools
import typing as t
from sqlglot import exp
from sqlglot._typing import E
-from sqlglot.helper import ensure_list, seq_get, subclasses
+from sqlglot.helper import (
+ ensure_list,
+ is_date_unit,
+ is_iso_date,
+ is_iso_datetime,
+ seq_get,
+ subclasses,
+)
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema, ensure_schema
@@ -20,10 +26,6 @@ if t.TYPE_CHECKING:
]
-# Interval units that operate on date components
-DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}
-
-
def annotate_types(
expression: E,
schema: t.Optional[t.Dict | Schema] = None,
@@ -60,43 +62,22 @@ def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[Type
return lambda self, e: self._annotate_with_type(e, data_type)
-def _is_iso_date(text: str) -> bool:
- try:
- datetime.date.fromisoformat(text)
- return True
- except ValueError:
- return False
-
-
-def _is_iso_datetime(text: str) -> bool:
- try:
- datetime.datetime.fromisoformat(text)
- return True
- except ValueError:
- return False
-
-
-def _coerce_literal_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
+def _coerce_date_literal(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type:
date_text = l.name
- unit = r.text("unit").lower()
-
- is_iso_date = _is_iso_date(date_text)
+ is_iso_date_ = is_iso_date(date_text)
- if is_iso_date and unit in DATE_UNITS:
- l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATE))
+ if is_iso_date_ and is_date_unit(unit):
return exp.DataType.Type.DATE
# An ISO date is also an ISO datetime, but not vice versa
- if is_iso_date or _is_iso_datetime(date_text):
- l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATETIME))
+ if is_iso_date_ or is_iso_datetime(date_text):
return exp.DataType.Type.DATETIME
return exp.DataType.Type.UNKNOWN
-def _coerce_date_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
- unit = r.text("unit").lower()
- if unit not in DATE_UNITS:
+def _coerce_date(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type:
+ if not is_date_unit(unit):
return exp.DataType.Type.DATETIME
return l.type.this if l.type else exp.DataType.Type.UNKNOWN
@@ -171,7 +152,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Date,
exp.DateFromParts,
exp.DateStrToDate,
- exp.DateTrunc,
exp.DiToDate,
exp.StrToDate,
exp.TimeStrToDate,
@@ -185,6 +165,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.DataType.Type.DOUBLE: {
exp.ApproxQuantile,
exp.Avg,
+ exp.Div,
exp.Exp,
exp.Ln,
exp.Log,
@@ -203,8 +184,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
},
exp.DataType.Type.INT: {
exp.Ceil,
- exp.DateDiff,
exp.DatetimeDiff,
+ exp.DateDiff,
exp.Extract,
exp.TimestampDiff,
exp.TimeDiff,
@@ -240,8 +221,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.GroupConcat,
exp.Initcap,
exp.Lower,
- exp.SafeConcat,
- exp.SafeDPipe,
exp.Substring,
exp.TimeToStr,
exp.TimeToTimeStr,
@@ -267,6 +246,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
for expr_type in expressions
},
+ exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
@@ -276,9 +256,11 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
- exp.DateAdd: lambda self, e: self._annotate_dateadd(e),
- exp.DateSub: lambda self, e: self._annotate_dateadd(e),
+ exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
+ exp.DateSub: lambda self, e: self._annotate_timeunit(e),
+ exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
+ exp.Div: lambda self, e: self._annotate_div(e),
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
@@ -288,6 +270,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
+ exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
@@ -306,13 +289,27 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
BINARY_COERCIONS: BinaryCoercions = {
**swap_all(
{
- (t, exp.DataType.Type.INTERVAL): _coerce_literal_and_interval
+ (t, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date_literal(
+ l, r.args.get("unit")
+ )
for t in exp.DataType.TEXT_TYPES
}
),
**swap_all(
{
- (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): _coerce_date_and_interval,
+ # text + numeric will yield the numeric type to match most dialects' semantics
+ (text, numeric): lambda l, r: t.cast(
+ exp.DataType.Type, l.type if l.type in exp.DataType.NUMERIC_TYPES else r.type
+ )
+ for text in exp.DataType.TEXT_TYPES
+ for numeric in exp.DataType.NUMERIC_TYPES
+ }
+ ),
+ **swap_all(
+ {
+ (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date(
+ l, r.args.get("unit")
+ ),
}
),
}
@@ -511,18 +508,17 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
return expression
- def _annotate_dateadd(self, expression: exp.IntervalOp) -> exp.IntervalOp:
+ def _annotate_timeunit(
+ self, expression: exp.TimeUnit | exp.DateTrunc
+ ) -> exp.TimeUnit | exp.DateTrunc:
self._annotate_args(expression)
if expression.this.type.this in exp.DataType.TEXT_TYPES:
- datatype = _coerce_literal_and_interval(expression.this, expression.interval())
- elif (
- expression.this.type.is_type(exp.DataType.Type.DATE)
- and expression.text("unit").lower() not in DATE_UNITS
- ):
- datatype = exp.DataType.Type.DATETIME
+ datatype = _coerce_date_literal(expression.this, expression.unit)
+ elif expression.this.type.this in exp.DataType.TEMPORAL_TYPES:
+ datatype = _coerce_date(expression.this, expression.unit)
else:
- datatype = expression.this.type
+ datatype = exp.DataType.Type.UNKNOWN
self._set_type(expression, datatype)
return expression
@@ -547,3 +543,19 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
self._set_type(expression, exp.DataType.Type.UNKNOWN)
return expression
+
+ def _annotate_div(self, expression: exp.Div) -> exp.Div:
+ self._annotate_args(expression)
+
+ left_type, right_type = expression.left.type.this, expression.right.type.this # type: ignore
+
+ if (
+ expression.args.get("typed")
+ and left_type in exp.DataType.INTEGER_TYPES
+ and right_type in exp.DataType.INTEGER_TYPES
+ ):
+ self._set_type(expression, exp.DataType.Type.BIGINT)
+ else:
+ self._set_type(expression, self._maybe_coerce(left_type, right_type))
+
+ return expression
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py
index fc5c348..faf18c6 100644
--- a/sqlglot/optimizer/canonicalize.py
+++ b/sqlglot/optimizer/canonicalize.py
@@ -1,8 +1,10 @@
from __future__ import annotations
import itertools
+import typing as t
from sqlglot import exp
+from sqlglot.helper import is_date_unit, is_iso_date, is_iso_datetime
def canonicalize(expression: exp.Expression) -> exp.Expression:
@@ -20,7 +22,7 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
expression = replace_date_funcs(expression)
expression = coerce_type(expression)
expression = remove_redundant_casts(expression)
- expression = ensure_bool_predicates(expression)
+ expression = ensure_bools(expression, _replace_int_predicate)
expression = remove_ascending_order(expression)
return expression
@@ -40,8 +42,22 @@ def replace_date_funcs(node: exp.Expression) -> exp.Expression:
return node
+COERCIBLE_DATE_OPS = (
+ exp.Add,
+ exp.Sub,
+ exp.EQ,
+ exp.NEQ,
+ exp.GT,
+ exp.GTE,
+ exp.LT,
+ exp.LTE,
+ exp.NullSafeEQ,
+ exp.NullSafeNEQ,
+)
+
+
def coerce_type(node: exp.Expression) -> exp.Expression:
- if isinstance(node, exp.Binary):
+ if isinstance(node, COERCIBLE_DATE_OPS):
_coerce_date(node.left, node.right)
elif isinstance(node, exp.Between):
_coerce_date(node.this, node.args["low"])
@@ -49,6 +65,10 @@ def coerce_type(node: exp.Expression) -> exp.Expression:
*exp.DataType.TEMPORAL_TYPES
):
_replace_cast(node.expression, exp.DataType.Type.DATETIME)
+ elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)):
+ _coerce_timeunit_arg(node.this, node.unit)
+ elif isinstance(node, exp.DateDiff):
+ _coerce_datediff_args(node)
return node
@@ -64,17 +84,21 @@ def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
return expression
-def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression:
+def ensure_bools(
+ expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None]
+) -> exp.Expression:
if isinstance(expression, exp.Connector):
- _replace_int_predicate(expression.left)
- _replace_int_predicate(expression.right)
-
- elif isinstance(expression, (exp.Where, exp.Having)) or (
+ replace_func(expression.left)
+ replace_func(expression.right)
+ elif isinstance(expression, exp.Not):
+ replace_func(expression.this)
# We can't replace num in CASE x WHEN num ..., because it's not the full predicate
- isinstance(expression, exp.If)
- and not (isinstance(expression.parent, exp.Case) and expression.parent.this)
+ elif isinstance(expression, exp.If) and not (
+ isinstance(expression.parent, exp.Case) and expression.parent.this
):
- _replace_int_predicate(expression.this)
+ replace_func(expression.this)
+ elif isinstance(expression, (exp.Where, exp.Having)):
+ replace_func(expression.this)
return expression
@@ -89,22 +113,59 @@ def remove_ascending_order(expression: exp.Expression) -> exp.Expression:
def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
for a, b in itertools.permutations([a, b]):
+ if isinstance(b, exp.Interval):
+ a = _coerce_timeunit_arg(a, b.unit)
if (
a.type
and a.type.this == exp.DataType.Type.DATE
and b.type
- and b.type.this not in (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL)
+ and b.type.this
+ not in (
+ exp.DataType.Type.DATE,
+ exp.DataType.Type.INTERVAL,
+ )
):
_replace_cast(b, exp.DataType.Type.DATE)
+def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression:
+ if not arg.type:
+ return arg
+
+ if arg.type.this in exp.DataType.TEXT_TYPES:
+ date_text = arg.name
+ is_iso_date_ = is_iso_date(date_text)
+
+ if is_iso_date_ and is_date_unit(unit):
+ return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATE))
+
+ # An ISO date is also an ISO datetime, but not vice versa
+ if is_iso_date_ or is_iso_datetime(date_text):
+ return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME))
+
+ elif arg.type.this == exp.DataType.Type.DATE and not is_date_unit(unit):
+ return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME))
+
+ return arg
+
+
+def _coerce_datediff_args(node: exp.DateDiff) -> None:
+ for e in (node.this, node.expression):
+ if e.type.this not in exp.DataType.TEMPORAL_TYPES:
+ e.replace(exp.cast(e.copy(), to=exp.DataType.Type.DATETIME))
+
+
def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None:
node.replace(exp.cast(node.copy(), to=to))
+# this was originally designed for presto, there is a similar transform for tsql
+# this is different in that it only operates on int types, this is because
+# presto has a boolean type whereas tsql doesn't (people use bits)
+# with y as (select true as x) select x = 0 FROM y -- illegal presto query
def _replace_int_predicate(expression: exp.Expression) -> None:
if isinstance(expression, exp.Coalesce):
for _, child in expression.iter_expressions():
_replace_int_predicate(child)
elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES:
- expression.replace(exp.NEQ(this=expression.copy(), expression=exp.Literal.number(0)))
+ expression.replace(expression.neq(0))
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
index b0b2b3d..a74bea7 100644
--- a/sqlglot/optimizer/merge_subqueries.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -186,13 +186,13 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
and not (
isinstance(from_or_join, exp.Join)
and inner_select.args.get("where")
- and from_or_join.side in {"FULL", "LEFT", "RIGHT"}
+ and from_or_join.side in ("FULL", "LEFT", "RIGHT")
)
and not (
isinstance(from_or_join, exp.From)
and inner_select.args.get("where")
and any(
- j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])
+ j.side in ("FULL", "RIGHT") for j in outer_scope.expression.args.get("joins", [])
)
)
and not _outer_select_joins_on_inner_select_join()
diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py
index 154256e..3361a33 100644
--- a/sqlglot/optimizer/normalize_identifiers.py
+++ b/sqlglot/optimizer/normalize_identifiers.py
@@ -13,7 +13,7 @@ def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
@t.overload
-def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Expression:
+def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier:
...
@@ -48,11 +48,11 @@ def normalize_identifiers(expression, dialect=None):
Returns:
The transformed expression.
"""
+ dialect = Dialect.get_or_raise(dialect)
+
if isinstance(expression, str):
expression = exp.parse_identifier(expression, dialect=dialect)
- dialect = Dialect.get_or_raise(dialect)
-
def _normalize(node: E) -> E:
if not node.meta.get("case_sensitive"):
exp.replace_children(node, _normalize)
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index abac63b..1c96e95 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -42,8 +42,8 @@ RULES = (
def optimize(
expression: str | exp.Expression,
schema: t.Optional[dict | Schema] = None,
- db: t.Optional[str] = None,
- catalog: t.Optional[str] = None,
+ db: t.Optional[str | exp.Identifier] = None,
+ catalog: t.Optional[str | exp.Identifier] = None,
dialect: DialectType = None,
rules: t.Sequence[t.Callable] = RULES,
**kwargs,
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index b06ea1d..742cdf5 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -8,7 +8,7 @@ from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get
-from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
+from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema
@@ -58,7 +58,7 @@ def qualify_columns(
if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver, using_column_tables, pseudocolumns)
- _qualify_outputs(scope)
+ qualify_outputs(scope)
_expand_group_by(scope)
_expand_order_by(scope, resolver)
@@ -237,7 +237,7 @@ def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
ordereds = order.expressions
for ordered, new_expression in zip(
ordereds,
- _expand_positional_references(scope, (o.this for o in ordereds)),
+ _expand_positional_references(scope, (o.this for o in ordereds), alias=True),
):
for agg in ordered.find_all(exp.AggFunc):
for col in agg.find_all(exp.Column):
@@ -259,17 +259,23 @@ def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
)
-def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]:
- new_nodes = []
+def _expand_positional_references(
+ scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
+) -> t.List[exp.Expression]:
+ new_nodes: t.List[exp.Expression] = []
for node in expressions:
if node.is_int:
- select = _select_by_pos(scope, t.cast(exp.Literal, node)).this
+ select = _select_by_pos(scope, t.cast(exp.Literal, node))
- if isinstance(select, exp.Literal):
- new_nodes.append(node)
+ if alias:
+ new_nodes.append(exp.column(select.args["alias"].copy()))
else:
- new_nodes.append(select.copy())
- scope.clear_cache()
+ select = select.this
+
+ if isinstance(select, exp.Literal):
+ new_nodes.append(node)
+ else:
+ new_nodes.append(select.copy())
else:
new_nodes.append(node)
@@ -307,7 +313,9 @@ def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
if column_table:
column.set("table", column_table)
elif column_table not in scope.sources and (
- not scope.parent or column_table not in scope.parent.sources
+ not scope.parent
+ or column_table not in scope.parent.sources
+ or not scope.is_correlated_subquery
):
# structs are used like tables (e.g. "struct"."field"), so they need to be qualified
# separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
@@ -381,15 +389,18 @@ def _expand_stars(
columns = [name for name in columns if name.upper() not in pseudocolumns]
if columns and "*" not in columns:
+ table_id = id(table)
+ columns_to_exclude = except_columns.get(table_id) or set()
+
if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
implicit_columns = [col for col in columns if col not in pivot_columns]
new_selections.extend(
exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
for name in implicit_columns + pivot_output_columns
+ if name not in columns_to_exclude
)
continue
- table_id = id(table)
for name in columns:
if name in using_column_tables and table in using_column_tables[name]:
if name in coalesced_columns:
@@ -406,7 +417,7 @@ def _expand_stars(
copy=False,
)
)
- elif name not in except_columns.get(table_id, set()):
+ elif name not in columns_to_exclude:
alias_ = replace_columns.get(table_id, {}).get(name, name)
column = exp.column(name, table=table)
new_selections.append(
@@ -448,10 +459,16 @@ def _add_replace_columns(
replace_columns[id(table)] = columns
-def _qualify_outputs(scope: Scope) -> None:
+def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
"""Ensure all output columns are aliased"""
- new_selections = []
+ if isinstance(scope_or_expression, exp.Expression):
+ scope = build_scope(scope_or_expression)
+ if not isinstance(scope, Scope):
+ return
+ else:
+ scope = scope_or_expression
+ new_selections = []
for i, (selection, aliased_column) in enumerate(
itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
):
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index 3a43e8f..57ecabe 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -1,8 +1,11 @@
+from __future__ import annotations
+
import itertools
import typing as t
from sqlglot import alias, exp
from sqlglot._typing import E
+from sqlglot.dialects.dialect import DialectType
from sqlglot.helper import csv_reader, name_sequence
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema
@@ -10,9 +13,10 @@ from sqlglot.schema import Schema
def qualify_tables(
expression: E,
- db: t.Optional[str] = None,
- catalog: t.Optional[str] = None,
+ db: t.Optional[str | exp.Identifier] = None,
+ catalog: t.Optional[str | exp.Identifier] = None,
schema: t.Optional[Schema] = None,
+ dialect: DialectType = None,
) -> E:
"""
Rewrite sqlglot AST to have fully qualified tables. Join constructs such as
@@ -33,11 +37,14 @@ def qualify_tables(
db: Database name
catalog: Catalog name
schema: A schema to populate
+ dialect: The dialect to parse catalog and schema into.
Returns:
The qualified expression.
"""
next_alias_name = name_sequence("_q_")
+ db = exp.parse_identifier(db, dialect=dialect) if db else None
+ catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None
for scope in traverse_scope(expression):
for derived_table in itertools.chain(scope.ctes, scope.derived_tables):
@@ -61,9 +68,9 @@ def qualify_tables(
if isinstance(source, exp.Table):
if isinstance(source.this, exp.Identifier):
if not source.args.get("db"):
- source.set("db", exp.to_identifier(db))
+ source.set("db", db)
if not source.args.get("catalog") and source.args.get("db"):
- source.set("catalog", exp.to_identifier(catalog))
+ source.set("catalog", catalog)
if not source.alias:
# Mutates the source by attaching an alias to it
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index 4af5b49..b7e527e 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import itertools
import logging
import typing as t
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index af03332..d4e2e60 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -507,6 +507,9 @@ def simplify_literals(expression, root=True):
return exp.Literal.number(value[1:])
return exp.Literal.number(f"-{value}")
+ if type(expression) in INVERSE_DATE_OPS:
+ return _simplify_binary(expression, expression.this, expression.interval()) or expression
+
return expression
@@ -530,22 +533,24 @@ def _simplify_binary(expression, a, b):
return exp.null()
if a.is_number and b.is_number:
- a = int(a.name) if a.is_int else Decimal(a.name)
- b = int(b.name) if b.is_int else Decimal(b.name)
+ num_a = int(a.name) if a.is_int else Decimal(a.name)
+ num_b = int(b.name) if b.is_int else Decimal(b.name)
if isinstance(expression, exp.Add):
- return exp.Literal.number(a + b)
- if isinstance(expression, exp.Sub):
- return exp.Literal.number(a - b)
+ return exp.Literal.number(num_a + num_b)
if isinstance(expression, exp.Mul):
- return exp.Literal.number(a * b)
+ return exp.Literal.number(num_a * num_b)
+
+ # We only simplify Sub, Div if a and b have the same parent because they're not associative
+ if isinstance(expression, exp.Sub):
+ return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
if isinstance(expression, exp.Div):
# engines have differing int div behavior so intdiv is not safe
- if isinstance(a, int) and isinstance(b, int):
+ if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
return None
- return exp.Literal.number(a / b)
+ return exp.Literal.number(num_a / num_b)
- boolean = eval_boolean(expression, a, b)
+ boolean = eval_boolean(expression, num_a, num_b)
if boolean:
return boolean
@@ -557,15 +562,21 @@ def _simplify_binary(expression, a, b):
elif _is_date_literal(a) and isinstance(b, exp.Interval):
a, b = extract_date(a), extract_interval(b)
if a and b:
- if isinstance(expression, exp.Add):
+ if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
return date_literal(a + b)
- if isinstance(expression, exp.Sub):
+ if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
return date_literal(a - b)
elif isinstance(a, exp.Interval) and _is_date_literal(b):
a, b = extract_interval(a), extract_date(b)
# you cannot subtract a date from an interval
if a and b and isinstance(expression, exp.Add):
return date_literal(a + b)
+ elif _is_date_literal(a) and _is_date_literal(b):
+ if isinstance(expression, exp.Predicate):
+ a, b = extract_date(a), extract_date(b)
+ boolean = eval_boolean(expression, a, b)
+ if boolean:
+ return boolean
return None
@@ -590,6 +601,11 @@ def simplify_parens(expression):
return expression
+NONNULL_CONSTANTS = (
+ exp.Literal,
+ exp.Boolean,
+)
+
CONSTANTS = (
exp.Literal,
exp.Boolean,
@@ -597,11 +613,19 @@ CONSTANTS = (
)
+def _is_nonnull_constant(expression: exp.Expression) -> bool:
+ return isinstance(expression, NONNULL_CONSTANTS) or _is_date_literal(expression)
+
+
+def _is_constant(expression: exp.Expression) -> bool:
+ return isinstance(expression, CONSTANTS) or _is_date_literal(expression)
+
+
def simplify_coalesce(expression):
# COALESCE(x) -> x
if (
isinstance(expression, exp.Coalesce)
- and not expression.expressions
+ and (not expression.expressions or _is_nonnull_constant(expression.this))
# COALESCE is also used as a Spark partitioning hint
and not isinstance(expression.parent, exp.Hint)
):
@@ -621,12 +645,12 @@ def simplify_coalesce(expression):
# This transformation is valid for non-constants,
# but it really only does anything if they are both constants.
- if not isinstance(other, CONSTANTS):
+ if not _is_constant(other):
return expression
# Find the first constant arg
for arg_index, arg in enumerate(coalesce.expressions):
- if isinstance(arg, CONSTANTS):
+ if _is_constant(other):
break
else:
return expression
@@ -656,7 +680,6 @@ def simplify_coalesce(expression):
CONCATS = (exp.Concat, exp.DPipe)
-SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)
def simplify_concat(expression):
@@ -672,10 +695,15 @@ def simplify_concat(expression):
sep_expr, *expressions = expression.expressions
sep = sep_expr.name
concat_type = exp.ConcatWs
+ args = {}
else:
expressions = expression.expressions
sep = ""
- concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
+ concat_type = exp.Concat
+ args = {
+ "safe": expression.args.get("safe"),
+ "coalesce": expression.args.get("coalesce"),
+ }
new_args = []
for is_string_group, group in itertools.groupby(
@@ -692,7 +720,7 @@ def simplify_concat(expression):
if concat_type is exp.ConcatWs:
new_args = [sep_expr] + new_args
- return concat_type(expressions=new_args)
+ return concat_type(expressions=new_args, **args)
def simplify_conditionals(expression):
@@ -947,7 +975,7 @@ def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.da
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
if isinstance(cast, exp.Cast):
to = cast.to
- elif isinstance(cast, exp.TsOrDsToDate):
+ elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
to = exp.DataType.build(exp.DataType.Type.DATE)
else:
return None
@@ -966,12 +994,11 @@ def _is_date_literal(expression: exp.Expression) -> bool:
def extract_interval(expression):
- n = int(expression.name)
- unit = expression.text("unit").lower()
-
try:
+ n = int(expression.name)
+ unit = expression.text("unit").lower()
return interval(unit, n)
- except (UnsupportedUnit, ModuleNotFoundError):
+ except (UnsupportedUnit, ModuleNotFoundError, ValueError):
return None
@@ -1099,8 +1126,6 @@ GEN_MAP = {
exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}",
exp.Div: lambda e: _binary(e, "/"),
exp.Dot: lambda e: _binary(e, "."),
- exp.DPipe: lambda e: _binary(e, "||"),
- exp.SafeDPipe: lambda e: _binary(e, "||"),
exp.EQ: lambda e: _binary(e, "="),
exp.GT: lambda e: _binary(e, ">"),
exp.GTE: lambda e: _binary(e, ">="),