summaryrefslogtreecommitdiffstats
path: root/sqlglot/transforms.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/transforms.py')
-rw-r--r--sqlglot/transforms.py70
1 files changed, 47 insertions, 23 deletions
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 00f278e..3643cd7 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -121,20 +121,9 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expr
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
other expressions. This transforms removes the precision from parameterized types in expressions.
"""
- return expression.transform(
- lambda node: exp.DataType(
- **{
- **node.args,
- "expressions": [
- node_expression
- for node_expression in node.expressions
- if isinstance(node_expression, exp.DataType)
- ],
- }
- )
- if isinstance(node, exp.DataType)
- else node,
- )
+ for node in expression.find_all(exp.DataType):
+ node.set("expressions", [e for e in node.expressions if isinstance(e, exp.DataType)])
+ return expression
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
@@ -240,12 +229,36 @@ def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
return expression
+def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
+ if (
+ isinstance(expression, exp.WithinGroup)
+ and isinstance(expression.this, (exp.PercentileCont, exp.PercentileDisc))
+ and isinstance(expression.expression, exp.Order)
+ ):
+ quantile = expression.this.this
+ input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
+ return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))
+
+ return expression
+
+
+def unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
+ if isinstance(expression, exp.Pivot):
+ expression.args["field"].transform(
+ lambda node: exp.column(node.output_name) if isinstance(node, exp.Column) else node,
+ copy=False,
+ )
+
+ return expression
+
+
def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
"""
Creates a new transform by chaining a sequence of transformations and converts the resulting
- expression to SQL, using an appropriate `Generator.TRANSFORMS` function.
+ expression to SQL, using either the "_sql" method corresponding to the resulting expression,
+ or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
Args:
transforms: sequence of transform functions. These will be called in order.
@@ -255,17 +268,28 @@ def preprocess(
"""
def _to_sql(self, expression: exp.Expression) -> str:
+ expression_type = type(expression)
+
expression = transforms[0](expression.copy())
for t in transforms[1:]:
expression = t(expression)
- return getattr(self, expression.key + "_sql")(expression)
- return _to_sql
+ _sql_handler = getattr(self, expression.key + "_sql", None)
+ if _sql_handler:
+ return _sql_handler(expression)
+
+ transforms_handler = self.TRANSFORMS.get(type(expression))
+ if transforms_handler:
+ # Ensures we don't enter an infinite loop. This can happen when the original expression
+ # has the same type as the final expression and there's no _sql method available for it,
+ # because then it'd re-enter _to_sql.
+ if expression_type is type(expression):
+ raise ValueError(
+ f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
+ )
+ return transforms_handler(self, expression)
-UNALIAS_GROUP = {exp.Group: preprocess([unalias_group])}
-ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on])}
-ELIMINATE_QUALIFY = {exp.Select: preprocess([eliminate_qualify])}
-REMOVE_PRECISION_PARAMETERIZED_TYPES = {
- exp.Cast: preprocess([remove_precision_parameterized_types])
-}
+ raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
+
+ return _to_sql