diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/transforms.py | 62 |
1 files changed, 49 insertions, 13 deletions
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 445fda6..03acc2b 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -141,7 +141,7 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expr def unnest_to_explode(expression: exp.Expression) -> exp.Expression: - """Convert cross join unnest into lateral view explode (used in presto -> hive).""" + """Convert cross join unnest into lateral view explode.""" if isinstance(expression, exp.Select): for join in expression.args.get("joins") or []: unnest = join.this @@ -166,7 +166,7 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression: def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]: - """Convert explode/posexplode into unnest (used in hive -> presto).""" + """Convert explode/posexplode into unnest.""" def _explode_to_unnest(expression: exp.Expression) -> exp.Expression: if isinstance(expression, exp.Select): @@ -199,11 +199,11 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp explode_alias = "" if isinstance(select, exp.Alias): - explode_alias = select.alias + explode_alias = select.args["alias"] alias = select elif isinstance(select, exp.Aliases): - pos_alias = select.aliases[0].name - explode_alias = select.aliases[1].name + pos_alias = select.aliases[0] + explode_alias = select.aliases[1] alias = select.replace(exp.alias_(select.this, "", copy=False)) else: alias = select.replace(exp.alias_(select, "")) @@ -230,9 +230,12 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp alias.set("alias", exp.to_identifier(explode_alias)) + series_table_alias = series.args["alias"].this column = exp.If( - this=exp.column(series_alias).eq(exp.column(pos_alias)), - true=exp.column(explode_alias), + this=exp.column(series_alias, table=series_table_alias).eq( + exp.column(pos_alias, table=unnest_source_alias) + ), + true=exp.column(explode_alias, table=unnest_source_alias), ) explode.replace(column) @@ -242,8 +245,10 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp expressions.insert( expressions.index(alias) + 1, exp.If( - this=exp.column(series_alias).eq(exp.column(pos_alias)), - true=exp.column(pos_alias), + this=exp.column(series_alias, table=series_table_alias).eq( + exp.column(pos_alias, table=unnest_source_alias) + ), + true=exp.column(pos_alias, table=unnest_source_alias), ).as_(pos_alias), ) expression.set("expressions", expressions) @@ -276,10 +281,12 @@ def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp size = size - 1 expression.where( - exp.column(series_alias) - .eq(exp.column(pos_alias)) + exp.column(series_alias, table=series_table_alias) + .eq(exp.column(pos_alias, table=unnest_source_alias)) .or_( - (exp.column(series_alias) > size).and_(exp.column(pos_alias).eq(size)) + (exp.column(series_alias, table=series_table_alias) > size).and_( + exp.column(pos_alias, table=unnest_source_alias).eq(size) + ) ), copy=False, ) @@ -386,14 +393,16 @@ def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression: full_outer_joins = [ (index, join) for index, join in enumerate(expression.args.get("joins") or []) - if join.side == "FULL" and join.kind == "OUTER" + if join.side == "FULL" ] if len(full_outer_joins) == 1: expression_copy = expression.copy() + expression.set("limit", None) index, full_outer_join = full_outer_joins[0] full_outer_join.set("side", "left") expression_copy.args["joins"][index].set("side", "right") + expression_copy.args.pop("with", None) # remove CTEs from RIGHT side return exp.union(expression, expression_copy, copy=False) @@ -430,6 +439,33 @@ def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression: return expression +def ensure_bools(expression: exp.Expression) -> exp.Expression: + """Converts numeric values used in conditions into explicit boolean expressions.""" + from sqlglot.optimizer.canonicalize import ensure_bools + + def _ensure_bool(node: exp.Expression) -> None: + if ( + node.is_number + or node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES) + or (isinstance(node, exp.Column) and not node.type) + ): + node.replace(node.neq(0)) + + for node, *_ in expression.walk(): + ensure_bools(node, _ensure_bool) + + return expression + + +def unqualify_columns(expression: exp.Expression) -> exp.Expression: + for column in expression.find_all(exp.Column): + # We only wanna pop off the table, db, catalog args + for part in column.parts[:-1]: + part.pop() + + return expression + + def preprocess( transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], ) -> t.Callable[[Generator, exp.Expression], str]: |