summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r--sqlglot/optimizer/annotate_types.py7
-rw-r--r--sqlglot/optimizer/normalize.py4
-rw-r--r--sqlglot/optimizer/qualify_columns.py5
-rw-r--r--sqlglot/optimizer/qualify_tables.py3
-rw-r--r--sqlglot/optimizer/simplify.py35
5 files changed, 33 insertions, 21 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index 99888c6..6238759 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -116,6 +116,9 @@ class TypeAnnotator:
exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
expr, exp.DataType.Type.VARCHAR
),
+ exp.ArraySize: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
+ exp.Map: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
+ exp.VarMap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL),
exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
@@ -335,7 +338,7 @@ class TypeAnnotator:
left_type = expression.left.type.this
right_type = expression.right.type.this
- if isinstance(expression, (exp.And, exp.Or)):
+ if isinstance(expression, exp.Connector):
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
expression.type = exp.DataType.Type.NULL
elif exp.DataType.Type.NULL in (left_type, right_type):
@@ -344,7 +347,7 @@ class TypeAnnotator:
)
else:
expression.type = exp.DataType.Type.BOOLEAN
- elif isinstance(expression, (exp.Condition, exp.Predicate)):
+ elif isinstance(expression, exp.Predicate):
expression.type = exp.DataType.Type.BOOLEAN
else:
expression.type = self._maybe_coerce(left_type, right_type)
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py
index f2df230..40668ef 100644
--- a/sqlglot/optimizer/normalize.py
+++ b/sqlglot/optimizer/normalize.py
@@ -46,7 +46,9 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
root = node is expression
original = node.copy()
try:
- node = while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache))
+ node = node.replace(
+ while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache))
+ )
except OptimizeError as e:
logger.info(e)
node.replace(original)
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 6eae2b5..0a31246 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -93,6 +93,7 @@ def _expand_using(scope, resolver):
if column not in columns:
columns[column] = k
+ source_table = ordered[-1]
ordered.append(join_table)
join_columns = resolver.get_source_columns(join_table)
conditions = []
@@ -102,8 +103,10 @@ def _expand_using(scope, resolver):
table = columns.get(identifier)
if not table or identifier not in join_columns:
- raise OptimizeError(f"Cannot automatically join: {identifier}")
+ if columns and join_columns:
+ raise OptimizeError(f"Cannot automatically join: {identifier}")
+ table = table or source_table
conditions.append(
exp.condition(
exp.EQ(
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index 93e1179..a719ebe 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -65,5 +65,8 @@ def qualify_tables(expression, db=None, catalog=None, schema=None):
if not table_alias.name:
table_alias.set("this", next_name())
+ if isinstance(udtf, exp.Values) and not table_alias.columns:
+ for i, e in enumerate(udtf.expressions[0].expressions):
+ table_alias.append("columns", exp.to_identifier(f"_col_{i}"))
return expression
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 28ae86d..4e6c910 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -201,23 +201,24 @@ def _simplify_comparison(expression, left, right, or_=False):
return left if (av < bv if or_ else av >= bv) else right
# we can't ever shortcut to true because the column could be null
- if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
- if not or_ and av <= bv:
- return exp.false()
- elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
- if not or_ and av >= bv:
- return exp.false()
- elif isinstance(a, exp.EQ):
- if isinstance(b, exp.LT):
- return exp.false() if av >= bv else a
- if isinstance(b, exp.LTE):
- return exp.false() if av > bv else a
- if isinstance(b, exp.GT):
- return exp.false() if av <= bv else a
- if isinstance(b, exp.GTE):
- return exp.false() if av < bv else a
- if isinstance(b, exp.NEQ):
- return exp.false() if av == bv else a
+ if not or_:
+ if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
+ if av <= bv:
+ return exp.false()
+ elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
+ if av >= bv:
+ return exp.false()
+ elif isinstance(a, exp.EQ):
+ if isinstance(b, exp.LT):
+ return exp.false() if av >= bv else a
+ if isinstance(b, exp.LTE):
+ return exp.false() if av > bv else a
+ if isinstance(b, exp.GT):
+ return exp.false() if av <= bv else a
+ if isinstance(b, exp.GTE):
+ return exp.false() if av < bv else a
+ if isinstance(b, exp.NEQ):
+ return exp.false() if av == bv else a
return None