summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/transforms.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/transforms.py')
-rw-r--r-- sqlglot/transforms.py | 207
1 files changed, 146 insertions, 61 deletions
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 70b9a31..ac9dd81 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -146,7 +146,7 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
if isinstance(unnest, exp.Unnest):
alias = unnest.args.get("alias")
- udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode
+ udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
expression.args["joins"].remove(join)
@@ -163,65 +163,134 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
return expression
-def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
- """Convert explode/posexplode into unnest (used in hive -> presto)."""
- if isinstance(expression, exp.Select):
- from sqlglot.optimizer.scope import Scope
-
- taken_select_names = set(expression.named_selects)
- taken_source_names = {name for name, _ in Scope(expression).references}
-
- for select in expression.selects:
- to_replace = select
-
- pos_alias = ""
- explode_alias = ""
-
- if isinstance(select, exp.Alias):
- explode_alias = select.alias
- select = select.this
- elif isinstance(select, exp.Aliases):
- pos_alias = select.aliases[0].name
- explode_alias = select.aliases[1].name
- select = select.this
-
- if isinstance(select, (exp.Explode, exp.Posexplode)):
- is_posexplode = isinstance(select, exp.Posexplode)
-
- explode_arg = select.this
- unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode)
+def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
+ def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
+ """Convert explode/posexplode into unnest (used in hive -> presto)."""
+ if isinstance(expression, exp.Select):
+ from sqlglot.optimizer.scope import Scope
+
+ taken_select_names = set(expression.named_selects)
+ taken_source_names = {name for name, _ in Scope(expression).references}
+
+ def new_name(names: t.Set[str], name: str) -> str:
+ name = find_new_name(names, name)
+ names.add(name)
+ return name
+
+ arrays: t.List[exp.Condition] = []
+ series_alias = new_name(taken_select_names, "pos")
+ series = exp.alias_(
+ exp.Unnest(
+ expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
+ ),
+ new_name(taken_source_names, "_u"),
+ table=[series_alias],
+ )
+
+ # iterate over a copy because expression.selects is mutated inside the loop
+ for select in expression.selects.copy():
+ explode = select.find(exp.Explode, exp.Posexplode)
+
+ if isinstance(explode, (exp.Explode, exp.Posexplode)):
+ pos_alias = ""
+ explode_alias = ""
+
+ if isinstance(select, exp.Alias):
+ explode_alias = select.alias
+ alias = select
+ elif isinstance(select, exp.Aliases):
+ pos_alias = select.aliases[0].name
+ explode_alias = select.aliases[1].name
+ alias = select.replace(exp.alias_(select.this, "", copy=False))
+ else:
+ alias = select.replace(exp.alias_(select, ""))
+ explode = alias.find(exp.Explode, exp.Posexplode)
+ assert explode
+
+ is_posexplode = isinstance(explode, exp.Posexplode)
+ explode_arg = explode.this
+
+ # This ensures that we won't use [POS]EXPLODE's argument as a new selection
+ if isinstance(explode_arg, exp.Column):
+ taken_select_names.add(explode_arg.output_name)
+
+ unnest_source_alias = new_name(taken_source_names, "_u")
+
+ if not explode_alias:
+ explode_alias = new_name(taken_select_names, "col")
+
+ if is_posexplode:
+ pos_alias = new_name(taken_select_names, "pos")
+
+ if not pos_alias:
+ pos_alias = new_name(taken_select_names, "pos")
+
+ alias.set("alias", exp.to_identifier(explode_alias))
+
+ column = exp.If(
+ this=exp.column(series_alias).eq(exp.column(pos_alias)),
+ true=exp.column(explode_alias),
+ )
- # This ensures that we won't use [POS]EXPLODE's argument as a new selection
- if isinstance(explode_arg, exp.Column):
- taken_select_names.add(explode_arg.output_name)
+ explode.replace(column)
- unnest_source_alias = find_new_name(taken_source_names, "_u")
- taken_source_names.add(unnest_source_alias)
+ if is_posexplode:
+ expressions = expression.expressions
+ expressions.insert(
+ expressions.index(alias) + 1,
+ exp.If(
+ this=exp.column(series_alias).eq(exp.column(pos_alias)),
+ true=exp.column(pos_alias),
+ ).as_(pos_alias),
+ )
+ expression.set("expressions", expressions)
+
+ if not arrays:
+ if expression.args.get("from"):
+ expression.join(series, copy=False)
+ else:
+ expression.from_(series, copy=False)
+
+ size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
+ arrays.append(size)
+
+ # trino doesn't support left join unnest with on conditions
+ # if it did, this would be much simpler
+ expression.join(
+ exp.alias_(
+ exp.Unnest(
+ expressions=[explode_arg.copy()],
+ offset=exp.to_identifier(pos_alias),
+ ),
+ unnest_source_alias,
+ table=[explode_alias],
+ ),
+ join_type="CROSS",
+ copy=False,
+ )
- if not explode_alias:
- explode_alias = find_new_name(taken_select_names, "col")
- taken_select_names.add(explode_alias)
+ if index_offset != 1:
+ size = size - 1
- if is_posexplode:
- pos_alias = find_new_name(taken_select_names, "pos")
- taken_select_names.add(pos_alias)
+ expression.where(
+ exp.column(series_alias)
+ .eq(exp.column(pos_alias))
+ .or_(
+ (exp.column(series_alias) > size).and_(exp.column(pos_alias).eq(size))
+ ),
+ copy=False,
+ )
- if is_posexplode:
- column_names = [explode_alias, pos_alias]
- to_replace.pop()
- expression.select(pos_alias, explode_alias, copy=False)
- else:
- column_names = [explode_alias]
- to_replace.replace(exp.column(explode_alias))
+ if arrays:
+ end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])
- unnest = exp.alias_(unnest, unnest_source_alias, table=column_names)
+ if index_offset != 1:
+ end = end - (1 - index_offset)
+ series.expressions[0].set("end", end)
- if not expression.args.get("from"):
- expression.from_(unnest, copy=False)
- else:
- expression.join(unnest, join_type="CROSS", copy=False)
+ return expression
- return expression
+ return _explode_to_unnest
PERCENTILES = (exp.PercentileCont, exp.PercentileDisc)
@@ -283,6 +352,31 @@ def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
return expression
+def timestamp_to_cast(expression: exp.Expression) -> exp.Expression:
+    """Replace an `exp.Timestamp` that has no `expression` arg with a cast.
+
+    The node's `this` is cast to `exp.DataType.Type.TIMESTAMP`. Any other
+    node, including an `exp.Timestamp` whose `expression` arg is set, is
+    returned unchanged.
+    """
+    if not isinstance(expression, exp.Timestamp) or expression.expression:
+        return expression
+
+    return exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP)
+
+
+def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
+    """Rewrite SEMI / ANTI joins that carry an ON condition as [NOT] EXISTS filters.
+
+    For every such join on an `exp.Select`, the join is popped from the query
+    and `EXISTS(SELECT 1 FROM <join.this> WHERE <on>)` (wrapped in NOT for
+    ANTI joins) is passed to `expression.where(...)`. The select is mutated in
+    place and returned; any other expression type is returned untouched.
+    Joins without an ON condition are left as they are.
+    """
+    if isinstance(expression, exp.Select):
+        # Walk a snapshot of the joins: join.pop() detaches the join from the
+        # same "joins" arg we are iterating, and depending on how pop() is
+        # implemented that could make a loop over the live list skip entries.
+        for join in list(expression.args.get("joins") or []):
+            on = join.args.get("on")
+            if on and join.kind in ("SEMI", "ANTI"):
+                subquery = exp.select("1").from_(join.this).where(on)
+                exists = exp.Exists(this=subquery)
+                if join.kind == "ANTI":
+                    exists = exists.not_(copy=False)
+
+                join.pop()
+                expression.where(exists, copy=False)
+
+    return expression
+
+
def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
@@ -327,12 +421,3 @@ def preprocess(
raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
return _to_sql
-
-
-def timestamp_to_cast(expression: exp.Expression) -> exp.Expression:
- if isinstance(expression, exp.Timestamp) and not expression.expression:
- return exp.cast(
- expression.this,
- to=exp.DataType.Type.TIMESTAMP,
- )
- return expression