diff options
Diffstat (limited to 'sqlglot/transforms.py')
-rw-r--r-- | sqlglot/transforms.py | 207 |
1 files changed, 146 insertions, 61 deletions
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 70b9a31..ac9dd81 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -146,7 +146,7 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression: if isinstance(unnest, exp.Unnest): alias = unnest.args.get("alias") - udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode + udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode expression.args["joins"].remove(join) @@ -163,65 +163,134 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression: return expression -def explode_to_unnest(expression: exp.Expression) -> exp.Expression: - """Convert explode/posexplode into unnest (used in hive -> presto).""" - if isinstance(expression, exp.Select): - from sqlglot.optimizer.scope import Scope - - taken_select_names = set(expression.named_selects) - taken_source_names = {name for name, _ in Scope(expression).references} - - for select in expression.selects: - to_replace = select - - pos_alias = "" - explode_alias = "" - - if isinstance(select, exp.Alias): - explode_alias = select.alias - select = select.this - elif isinstance(select, exp.Aliases): - pos_alias = select.aliases[0].name - explode_alias = select.aliases[1].name - select = select.this - - if isinstance(select, (exp.Explode, exp.Posexplode)): - is_posexplode = isinstance(select, exp.Posexplode) - - explode_arg = select.this - unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode) +def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]: + def _explode_to_unnest(expression: exp.Expression) -> exp.Expression: + """Convert explode/posexplode into unnest (used in hive -> presto).""" + if isinstance(expression, exp.Select): + from sqlglot.optimizer.scope import Scope + + taken_select_names = set(expression.named_selects) + taken_source_names = {name for name, _ in Scope(expression).references} + + def new_name(names: t.Set[str], name: str) -> str: + name = find_new_name(names, name) + names.add(name) + return name + + arrays: t.List[exp.Condition] = [] + series_alias = new_name(taken_select_names, "pos") + series = exp.alias_( + exp.Unnest( + expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))] + ), + new_name(taken_source_names, "_u"), + table=[series_alias], + ) + + # we use list here because expression.selects is mutated inside the loop + for select in expression.selects.copy(): + explode = select.find(exp.Explode, exp.Posexplode) + + if isinstance(explode, (exp.Explode, exp.Posexplode)): + pos_alias = "" + explode_alias = "" + + if isinstance(select, exp.Alias): + explode_alias = select.alias + alias = select + elif isinstance(select, exp.Aliases): + pos_alias = select.aliases[0].name + explode_alias = select.aliases[1].name + alias = select.replace(exp.alias_(select.this, "", copy=False)) + else: + alias = select.replace(exp.alias_(select, "")) + explode = alias.find(exp.Explode, exp.Posexplode) + assert explode + + is_posexplode = isinstance(explode, exp.Posexplode) + explode_arg = explode.this + + # This ensures that we won't use [POS]EXPLODE's argument as a new selection + if isinstance(explode_arg, exp.Column): + taken_select_names.add(explode_arg.output_name) + + unnest_source_alias = new_name(taken_source_names, "_u") + + if not explode_alias: + explode_alias = new_name(taken_select_names, "col") + + if is_posexplode: + pos_alias = new_name(taken_select_names, "pos") + + if not pos_alias: + pos_alias = new_name(taken_select_names, "pos") + + alias.set("alias", exp.to_identifier(explode_alias)) + + column = exp.If( + this=exp.column(series_alias).eq(exp.column(pos_alias)), + true=exp.column(explode_alias), + ) - # This ensures that we won't use [POS]EXPLODE's argument as a new selection - if isinstance(explode_arg, exp.Column): - taken_select_names.add(explode_arg.output_name) + explode.replace(column) - unnest_source_alias = find_new_name(taken_source_names, "_u") - taken_source_names.add(unnest_source_alias) + if is_posexplode: + expressions = expression.expressions + expressions.insert( + expressions.index(alias) + 1, + exp.If( + this=exp.column(series_alias).eq(exp.column(pos_alias)), + true=exp.column(pos_alias), + ).as_(pos_alias), + ) + expression.set("expressions", expressions) + + if not arrays: + if expression.args.get("from"): + expression.join(series, copy=False) + else: + expression.from_(series, copy=False) + + size: exp.Condition = exp.ArraySize(this=explode_arg.copy()) + arrays.append(size) + + # trino doesn't support left join unnest with on conditions + # if it did, this would be much simpler + expression.join( + exp.alias_( + exp.Unnest( + expressions=[explode_arg.copy()], + offset=exp.to_identifier(pos_alias), + ), + unnest_source_alias, + table=[explode_alias], + ), + join_type="CROSS", + copy=False, + ) - if not explode_alias: - explode_alias = find_new_name(taken_select_names, "col") - taken_select_names.add(explode_alias) + if index_offset != 1: + size = size - 1 - if is_posexplode: - pos_alias = find_new_name(taken_select_names, "pos") - taken_select_names.add(pos_alias) + expression.where( + exp.column(series_alias) + .eq(exp.column(pos_alias)) + .or_( + (exp.column(series_alias) > size).and_(exp.column(pos_alias).eq(size)) + ), + copy=False, + ) - if is_posexplode: - column_names = [explode_alias, pos_alias] - to_replace.pop() - expression.select(pos_alias, explode_alias, copy=False) - else: - column_names = [explode_alias] - to_replace.replace(exp.column(explode_alias)) + if arrays: + end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:]) - unnest = exp.alias_(unnest, unnest_source_alias, table=column_names) + if index_offset != 1: + end = end - (1 - index_offset) + series.expressions[0].set("end", end) - if not expression.args.get("from"): - expression.from_(unnest, copy=False) - else: - expression.join(unnest, join_type="CROSS", copy=False) + return expression - return expression + return _explode_to_unnest PERCENTILES = (exp.PercentileCont, exp.PercentileDisc) @@ -283,6 +352,31 @@ def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression: return expression +def timestamp_to_cast(expression: exp.Expression) -> exp.Expression: + if isinstance(expression, exp.Timestamp) and not expression.expression: + return exp.cast( + expression.this, + to=exp.DataType.Type.TIMESTAMP, + ) + return expression + + +def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression: + if isinstance(expression, exp.Select): + for join in expression.args.get("joins") or []: + on = join.args.get("on") + if on and join.kind in ("SEMI", "ANTI"): + subquery = exp.select("1").from_(join.this).where(on) + exists = exp.Exists(this=subquery) + if join.kind == "ANTI": + exists = exists.not_(copy=False) + + join.pop() + expression.where(exists, copy=False) + + return expression + + def preprocess( transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], ) -> t.Callable[[Generator, exp.Expression], str]: @@ -327,12 +421,3 @@ def preprocess( raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.") return _to_sql - - -def timestamp_to_cast(expression: exp.Expression) -> exp.Expression: - if isinstance(expression, exp.Timestamp) and not expression.expression: - return exp.cast( - expression.this, - to=exp.DataType.Type.TIMESTAMP, - ) - return expression |