summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r--sqlglot/optimizer/annotate_types.py1
-rw-r--r--sqlglot/optimizer/qualify_columns.py19
-rw-r--r--sqlglot/optimizer/simplify.py26
3 files changed, 26 insertions, 20 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index 0fc5f4c..e7cb80b 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -105,6 +105,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.CurrentDate,
exp.Date,
exp.DateAdd,
+ exp.DateFromParts,
exp.DateStrToDate,
exp.DateSub,
exp.DateTrunc,
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 8c3f599..435585c 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -220,25 +220,6 @@ def _expand_group_by(scope: Scope):
group.set("expressions", _expand_positional_references(scope, group.expressions))
expression.set("group", group)
- # group by expressions cannot be simplified, for example
- # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
- # the projection must exactly match the group by key
- groups = set(group.expressions)
- group.meta["final"] = True
-
- for e in expression.selects:
- for node, *_ in e.walk():
- if node in groups:
- e.meta["final"] = True
- break
-
- having = expression.args.get("having")
- if having:
- for node, *_ in having.walk():
- if node in groups:
- having.meta["final"] = True
- break
-
def _expand_order_by(scope: Scope, resolver: Resolver):
order = scope.expression.args.get("order")
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 1a2d82c..e247f58 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -8,6 +8,9 @@ from sqlglot import exp
from sqlglot.generator import cached_generator
from sqlglot.helper import first, while_changing
+# Final means that an expression should not be simplified
+FINAL = "final"
+
def simplify(expression):
"""
@@ -27,8 +30,29 @@ def simplify(expression):
generate = cached_generator()
+ # group by expressions cannot be simplified, for example
+ # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
+ # the projection must exactly match the group by key
+ for group in expression.find_all(exp.Group):
+ select = group.parent
+ groups = set(group.expressions)
+ group.meta[FINAL] = True
+
+ for e in select.selects:
+ for node, *_ in e.walk():
+ if node in groups:
+ e.meta[FINAL] = True
+ break
+
+ having = select.args.get("having")
+ if having:
+ for node, *_ in having.walk():
+ if node in groups:
+ having.meta[FINAL] = True
+ break
+
def _simplify(expression, root=True):
- if expression.meta.get("final"):
+ if expression.meta.get(FINAL):
return expression
node = expression
node = rewrite_between(node)