summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r--sqlglot/optimizer/annotate_types.py26
-rw-r--r--sqlglot/optimizer/canonicalize.py48
-rw-r--r--sqlglot/optimizer/eliminate_joins.py13
-rw-r--r--sqlglot/optimizer/optimize_joins.py4
-rw-r--r--sqlglot/optimizer/optimizer.py4
-rw-r--r--sqlglot/optimizer/pushdown_predicates.py4
-rw-r--r--sqlglot/optimizer/qualify_columns.py4
-rw-r--r--sqlglot/optimizer/qualify_tables.py14
-rw-r--r--sqlglot/optimizer/simplify.py6
-rw-r--r--sqlglot/optimizer/unnest_subqueries.py2
10 files changed, 96 insertions, 29 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index 96331e2..191ea52 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -245,23 +245,31 @@ class TypeAnnotator:
def annotate(self, expression):
if isinstance(expression, self.TRAVERSABLES):
for scope in traverse_scope(expression):
- subscope_selects = {
- name: {select.alias_or_name: select for select in source.selects}
- for name, source in scope.sources.items()
- if isinstance(source, Scope)
- }
-
+ selects = {}
+ for name, source in scope.sources.items():
+ if not isinstance(source, Scope):
+ continue
+ if isinstance(source.expression, exp.Values):
+ selects[name] = {
+ alias: column
+ for alias, column in zip(
+ source.expression.alias_column_names,
+ source.expression.expressions[0].expressions,
+ )
+ }
+ else:
+ selects[name] = {
+ select.alias_or_name: select for select in source.expression.selects
+ }
# First annotate the current scope's column references
for col in scope.columns:
source = scope.sources[col.table]
if isinstance(source, exp.Table):
col.type = self.schema.get_column_type(source, col)
else:
- col.type = subscope_selects[col.table][col.name].type
-
+ col.type = selects[col.table][col.name].type
# Then (possibly) annotate the remaining expressions in the scope
self._maybe_annotate(scope.expression)
-
return self._maybe_annotate(expression) # This takes care of non-traversable expressions
def _maybe_annotate(self, expression):
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py
new file mode 100644
index 0000000..9b3d98a
--- /dev/null
+++ b/sqlglot/optimizer/canonicalize.py
@@ -0,0 +1,48 @@
+import itertools
+
+from sqlglot import exp
+
+
+def canonicalize(expression: exp.Expression) -> exp.Expression:
+ """Converts a sql expression into a standard form.
+
+ This method relies on annotate_types because many of the
+ conversions rely on type inference.
+
+ Args:
+ expression: The expression to canonicalize.
+ """
+ exp.replace_children(expression, canonicalize)
+ expression = add_text_to_concat(expression)
+ expression = coerce_type(expression)
+ return expression
+
+
+def add_text_to_concat(node: exp.Expression) -> exp.Expression:
+ if isinstance(node, exp.Add) and node.type in exp.DataType.TEXT_TYPES:
+ node = exp.Concat(this=node.this, expression=node.expression)
+ return node
+
+
+def coerce_type(node: exp.Expression) -> exp.Expression:
+ if isinstance(node, exp.Binary):
+ _coerce_date(node.left, node.right)
+ elif isinstance(node, exp.Between):
+ _coerce_date(node.this, node.args["low"])
+ elif isinstance(node, exp.Extract):
+ if node.expression.type not in exp.DataType.TEMPORAL_TYPES:
+ _replace_cast(node.expression, "datetime")
+ return node
+
+
+def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
+ for a, b in itertools.permutations([a, b]):
+ if a.type == exp.DataType.Type.DATE and b.type != exp.DataType.Type.DATE:
+ _replace_cast(b, "date")
+
+
+def _replace_cast(node: exp.Expression, to: str) -> None:
+ data_type = exp.DataType.build(to)
+ cast = exp.Cast(this=node.copy(), to=data_type)
+ cast.type = data_type
+ node.replace(cast)
diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py
index 29621af..de4e011 100644
--- a/sqlglot/optimizer/eliminate_joins.py
+++ b/sqlglot/optimizer/eliminate_joins.py
@@ -128,8 +128,8 @@ def join_condition(join):
Tuple of (source key, join key, remaining predicate)
"""
name = join.this.alias_or_name
- on = join.args.get("on") or exp.TRUE
- on = on.copy()
+ on = (join.args.get("on") or exp.true()).copy()
+ on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
source_key = []
join_key = []
@@ -141,7 +141,7 @@ def join_condition(join):
#
# should pull y.b as the join key and x.a as the source key
if normalized(on):
- for condition in on.flatten() if isinstance(on, exp.And) else [on]:
+ for condition in on.flatten():
if isinstance(condition, exp.EQ):
left, right = condition.unnest_operands()
left_tables = exp.column_table_names(left)
@@ -150,13 +150,12 @@ def join_condition(join):
if name in left_tables and name not in right_tables:
join_key.append(left)
source_key.append(right)
- condition.replace(exp.TRUE)
+ condition.replace(exp.true())
elif name in right_tables and name not in left_tables:
join_key.append(right)
source_key.append(left)
- condition.replace(exp.TRUE)
+ condition.replace(exp.true())
on = simplify(on)
- remaining_condition = None if on == exp.TRUE else on
-
+ remaining_condition = None if on == exp.true() else on
return source_key, join_key, remaining_condition
diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py
index 40e4ab1..fd69832 100644
--- a/sqlglot/optimizer/optimize_joins.py
+++ b/sqlglot/optimizer/optimize_joins.py
@@ -29,7 +29,7 @@ def optimize_joins(expression):
if isinstance(on, exp.Connector):
for predicate in on.flatten():
if name in exp.column_table_names(predicate):
- predicate.replace(exp.TRUE)
+ predicate.replace(exp.true())
join.on(predicate, copy=False)
expression = reorder_joins(expression)
@@ -70,6 +70,6 @@ def normalize(expression):
def other_table_names(join, exclude):
return [
name
- for name in (exp.column_table_names(join.args.get("on") or exp.TRUE))
+ for name in (exp.column_table_names(join.args.get("on") or exp.true()))
if name != exclude
]
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index b2ed062..d0e38cd 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -1,4 +1,6 @@
import sqlglot
+from sqlglot.optimizer.annotate_types import annotate_types
+from sqlglot.optimizer.canonicalize import canonicalize
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
@@ -28,6 +30,8 @@ RULES = (
merge_subqueries,
eliminate_joins,
eliminate_ctes,
+ annotate_types,
+ canonicalize,
quote_identities,
)
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py
index 6364f65..f92e5c3 100644
--- a/sqlglot/optimizer/pushdown_predicates.py
+++ b/sqlglot/optimizer/pushdown_predicates.py
@@ -64,11 +64,11 @@ def pushdown_cnf(predicates, scope, scope_ref_count):
for predicate in predicates:
for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
if isinstance(node, exp.Join):
- predicate.replace(exp.TRUE)
+ predicate.replace(exp.true())
node.on(predicate, copy=False)
break
if isinstance(node, exp.Select):
- predicate.replace(exp.TRUE)
+ predicate.replace(exp.true())
node.where(replace_aliases(node, predicate), copy=False)
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 69fe2b8..e6e6dc9 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -382,9 +382,7 @@ class _Resolver:
raise OptimizeError(str(e)) from e
if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
- values_alias = source.expression.parent
- if hasattr(values_alias, "alias_column_names"):
- return values_alias.alias_column_names
+ return source.expression.alias_column_names
# Otherwise, if referencing another scope, return that scope's named selects
return source.expression.named_selects
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index 0e467d3..5d8e0d9 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -1,10 +1,11 @@
import itertools
from sqlglot import alias, exp
+from sqlglot.helper import csv_reader
from sqlglot.optimizer.scope import traverse_scope
-def qualify_tables(expression, db=None, catalog=None):
+def qualify_tables(expression, db=None, catalog=None, schema=None):
"""
Rewrite sqlglot AST to have fully qualified tables.
@@ -18,6 +19,7 @@ def qualify_tables(expression, db=None, catalog=None):
expression (sqlglot.Expression): expression to qualify
db (str): Database name
catalog (str): Catalog name
+ schema: A schema to populate
Returns:
sqlglot.Expression: qualified expression
"""
@@ -41,7 +43,7 @@ def qualify_tables(expression, db=None, catalog=None):
source.set("catalog", exp.to_identifier(catalog))
if not source.alias:
- source.replace(
+ source = source.replace(
alias(
source.copy(),
source.this if identifier else f"_q_{next(sequence)}",
@@ -49,4 +51,12 @@ def qualify_tables(expression, db=None, catalog=None):
)
)
+ if schema and isinstance(source.this, exp.ReadCSV):
+ with csv_reader(source.this) as reader:
+ header = next(reader)
+ columns = next(reader)
+ schema.add_table(
+ source, {k: type(v).__name__ for k, v in zip(header, columns)}
+ )
+
return expression
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index d759e86..c432c59 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -189,11 +189,11 @@ def absorb_and_eliminate(expression):
# absorb
if is_complement(b, aa):
- aa.replace(exp.TRUE if kind == exp.And else exp.FALSE)
+ aa.replace(exp.true() if kind == exp.And else exp.false())
elif is_complement(b, ab):
- ab.replace(exp.TRUE if kind == exp.And else exp.FALSE)
+ ab.replace(exp.true() if kind == exp.And else exp.false())
elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
- a.replace(exp.FALSE if kind == exp.And else exp.TRUE)
+ a.replace(exp.false() if kind == exp.And else exp.true())
elif isinstance(b, kind):
# eliminate
rhs = b.unnest_operands()
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index f41a84e..dbd680b 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -169,7 +169,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
select.parent.replace(alias)
for key, column, predicate in keys:
- predicate.replace(exp.TRUE)
+ predicate.replace(exp.true())
nested = exp.column(key_aliases[key], table_alias)
if key in group_by: