diff options
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r-- | sqlglot/optimizer/annotate_types.py | 26 | ||||
-rw-r--r-- | sqlglot/optimizer/canonicalize.py | 48 | ||||
-rw-r--r-- | sqlglot/optimizer/eliminate_joins.py | 13 | ||||
-rw-r--r-- | sqlglot/optimizer/optimize_joins.py | 4 | ||||
-rw-r--r-- | sqlglot/optimizer/optimizer.py | 4 | ||||
-rw-r--r-- | sqlglot/optimizer/pushdown_predicates.py | 4 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py | 4 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_tables.py | 14 | ||||
-rw-r--r-- | sqlglot/optimizer/simplify.py | 6 | ||||
-rw-r--r-- | sqlglot/optimizer/unnest_subqueries.py | 2 |
10 files changed, 96 insertions, 29 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 96331e2..191ea52 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -245,23 +245,31 @@ class TypeAnnotator: def annotate(self, expression): if isinstance(expression, self.TRAVERSABLES): for scope in traverse_scope(expression): - subscope_selects = { - name: {select.alias_or_name: select for select in source.selects} - for name, source in scope.sources.items() - if isinstance(source, Scope) - } - + selects = {} + for name, source in scope.sources.items(): + if not isinstance(source, Scope): + continue + if isinstance(source.expression, exp.Values): + selects[name] = { + alias: column + for alias, column in zip( + source.expression.alias_column_names, + source.expression.expressions[0].expressions, + ) + } + else: + selects[name] = { + select.alias_or_name: select for select in source.expression.selects + } # First annotate the current scope's column references for col in scope.columns: source = scope.sources[col.table] if isinstance(source, exp.Table): col.type = self.schema.get_column_type(source, col) else: - col.type = subscope_selects[col.table][col.name].type - + col.type = selects[col.table][col.name].type # Then (possibly) annotate the remaining expressions in the scope self._maybe_annotate(scope.expression) - return self._maybe_annotate(expression) # This takes care of non-traversable expressions def _maybe_annotate(self, expression): diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py new file mode 100644 index 0000000..9b3d98a --- /dev/null +++ b/sqlglot/optimizer/canonicalize.py @@ -0,0 +1,48 @@ +import itertools + +from sqlglot import exp + + +def canonicalize(expression: exp.Expression) -> exp.Expression: + """Converts a sql expression into a standard form. + + This method relies on annotate_types because many of the + conversions rely on type inference. + + Args: + expression: The expression to canonicalize. + """ + exp.replace_children(expression, canonicalize) + expression = add_text_to_concat(expression) + expression = coerce_type(expression) + return expression + + +def add_text_to_concat(node: exp.Expression) -> exp.Expression: + if isinstance(node, exp.Add) and node.type in exp.DataType.TEXT_TYPES: + node = exp.Concat(this=node.this, expression=node.expression) + return node + + +def coerce_type(node: exp.Expression) -> exp.Expression: + if isinstance(node, exp.Binary): + _coerce_date(node.left, node.right) + elif isinstance(node, exp.Between): + _coerce_date(node.this, node.args["low"]) + elif isinstance(node, exp.Extract): + if node.expression.type not in exp.DataType.TEMPORAL_TYPES: + _replace_cast(node.expression, "datetime") + return node + + +def _coerce_date(a: exp.Expression, b: exp.Expression) -> None: + for a, b in itertools.permutations([a, b]): + if a.type == exp.DataType.Type.DATE and b.type != exp.DataType.Type.DATE: + _replace_cast(b, "date") + + +def _replace_cast(node: exp.Expression, to: str) -> None: + data_type = exp.DataType.build(to) + cast = exp.Cast(this=node.copy(), to=data_type) + cast.type = data_type + node.replace(cast) diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py index 29621af..de4e011 100644 --- a/sqlglot/optimizer/eliminate_joins.py +++ b/sqlglot/optimizer/eliminate_joins.py @@ -128,8 +128,8 @@ def join_condition(join): Tuple of (source key, join key, remaining predicate) """ name = join.this.alias_or_name - on = join.args.get("on") or exp.TRUE - on = on.copy() + on = (join.args.get("on") or exp.true()).copy() + on = on if isinstance(on, exp.And) else exp.and_(on, exp.true()) source_key = [] join_key = [] @@ -141,7 +141,7 @@ def join_condition(join): # # should pull y.b as the join key and x.a as the source key if normalized(on): - for condition in on.flatten() if isinstance(on, exp.And) else [on]: + for condition in on.flatten(): if isinstance(condition, exp.EQ): left, right = condition.unnest_operands() left_tables = exp.column_table_names(left) @@ -150,13 +150,12 @@ def join_condition(join): if name in left_tables and name not in right_tables: join_key.append(left) source_key.append(right) - condition.replace(exp.TRUE) + condition.replace(exp.true()) elif name in right_tables and name not in left_tables: join_key.append(right) source_key.append(left) - condition.replace(exp.TRUE) + condition.replace(exp.true()) on = simplify(on) - remaining_condition = None if on == exp.TRUE else on - + remaining_condition = None if on == exp.true() else on return source_key, join_key, remaining_condition diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py index 40e4ab1..fd69832 100644 --- a/sqlglot/optimizer/optimize_joins.py +++ b/sqlglot/optimizer/optimize_joins.py @@ -29,7 +29,7 @@ def optimize_joins(expression): if isinstance(on, exp.Connector): for predicate in on.flatten(): if name in exp.column_table_names(predicate): - predicate.replace(exp.TRUE) + predicate.replace(exp.true()) join.on(predicate, copy=False) expression = reorder_joins(expression) @@ -70,6 +70,6 @@ def normalize(expression): def other_table_names(join, exclude): return [ name - for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) + for name in (exp.column_table_names(join.args.get("on") or exp.true())) if name != exclude ] diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index b2ed062..d0e38cd 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -1,4 +1,6 @@ import sqlglot +from sqlglot.optimizer.annotate_types import annotate_types +from sqlglot.optimizer.canonicalize import canonicalize from sqlglot.optimizer.eliminate_ctes import eliminate_ctes from sqlglot.optimizer.eliminate_joins import eliminate_joins from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries @@ -28,6 +30,8 @@ RULES = ( merge_subqueries, eliminate_joins, eliminate_ctes, + annotate_types, + canonicalize, quote_identities, ) diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index 6364f65..f92e5c3 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -64,11 +64,11 @@ def pushdown_cnf(predicates, scope, scope_ref_count): for predicate in predicates: for node in nodes_for_predicate(predicate, scope, scope_ref_count).values(): if isinstance(node, exp.Join): - predicate.replace(exp.TRUE) + predicate.replace(exp.true()) node.on(predicate, copy=False) break if isinstance(node, exp.Select): - predicate.replace(exp.TRUE) + predicate.replace(exp.true()) node.where(replace_aliases(node, predicate), copy=False) diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 69fe2b8..e6e6dc9 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -382,9 +382,7 @@ class _Resolver: raise OptimizeError(str(e)) from e if isinstance(source, Scope) and isinstance(source.expression, exp.Values): - values_alias = source.expression.parent - if hasattr(values_alias, "alias_column_names"): - return values_alias.alias_column_names + return source.expression.alias_column_names # Otherwise, if referencing another scope, return that scope's named selects return source.expression.named_selects diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index 0e467d3..5d8e0d9 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -1,10 +1,11 @@ import itertools from sqlglot import alias, exp +from sqlglot.helper import csv_reader from sqlglot.optimizer.scope import traverse_scope -def qualify_tables(expression, db=None, catalog=None): +def qualify_tables(expression, db=None, catalog=None, schema=None): """ Rewrite sqlglot AST to have fully qualified tables. @@ -18,6 +19,7 @@ def qualify_tables(expression, db=None, catalog=None): expression (sqlglot.Expression): expression to qualify db (str): Database name catalog (str): Catalog name + schema: A schema to populate Returns: sqlglot.Expression: qualified expression """ @@ -41,7 +43,7 @@ def qualify_tables(expression, db=None, catalog=None): source.set("catalog", exp.to_identifier(catalog)) if not source.alias: - source.replace( + source = source.replace( alias( source.copy(), source.this if identifier else f"_q_{next(sequence)}", @@ -49,4 +51,12 @@ def qualify_tables(expression, db=None, catalog=None): ) ) + if schema and isinstance(source.this, exp.ReadCSV): + with csv_reader(source.this) as reader: + header = next(reader) + columns = next(reader) + schema.add_table( + source, {k: type(v).__name__ for k, v in zip(header, columns)} + ) + return expression diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index d759e86..c432c59 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -189,11 +189,11 @@ def absorb_and_eliminate(expression): # absorb if is_complement(b, aa): - aa.replace(exp.TRUE if kind == exp.And else exp.FALSE) + aa.replace(exp.true() if kind == exp.And else exp.false()) elif is_complement(b, ab): - ab.replace(exp.TRUE if kind == exp.And else exp.FALSE) + ab.replace(exp.true() if kind == exp.And else exp.false()) elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()): - a.replace(exp.FALSE if kind == exp.And else exp.TRUE) + a.replace(exp.false() if kind == exp.And else exp.true()) elif isinstance(b, kind): # eliminate rhs = b.unnest_operands() diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index f41a84e..dbd680b 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -169,7 +169,7 @@ def decorrelate(select, parent_select, external_columns, sequence): select.parent.replace(alias) for key, column, predicate in keys: - predicate.replace(exp.TRUE) + predicate.replace(exp.true()) nested = exp.column(key_aliases[key], table_alias) if key in group_by: |