diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/transforms.py | 37 |
1 files changed, 33 insertions, 4 deletions
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 66ab884..70b9a31 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -224,10 +224,27 @@ def explode_to_unnest(expression: exp.Expression) -> exp.Expression: return expression +PERCENTILES = (exp.PercentileCont, exp.PercentileDisc) + + +def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression: + if ( + isinstance(expression, PERCENTILES) + and not isinstance(expression.parent, exp.WithinGroup) + and expression.expression + ): + column = expression.this.pop() + expression.set("this", expression.expression.pop()) + order = exp.Order(expressions=[exp.Ordered(this=column)]) + expression = exp.WithinGroup(this=expression, expression=order) + + return expression + + def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression: if ( isinstance(expression, exp.WithinGroup) - and isinstance(expression.this, (exp.PercentileCont, exp.PercentileDisc)) + and isinstance(expression.this, PERCENTILES) and isinstance(expression.expression, exp.Order) ): quantile = expression.this.this @@ -294,10 +311,13 @@ def preprocess( transforms_handler = self.TRANSFORMS.get(type(expression)) if transforms_handler: - # Ensures we don't enter an infinite loop. This can happen when the original expression - # has the same type as the final expression and there's no _sql method available for it, - # because then it'd re-enter _to_sql. if expression_type is type(expression): + if isinstance(expression, exp.Func): + return self.function_fallback_sql(expression) + + # Ensures we don't enter an infinite loop. This can happen when the original expression + # has the same type as the final expression and there's no _sql method available for it, + # because then it'd re-enter _to_sql. raise ValueError( f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed." ) @@ -307,3 +327,12 @@ def preprocess( raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.") return _to_sql + + +def timestamp_to_cast(expression: exp.Expression) -> exp.Expression: + if isinstance(expression, exp.Timestamp) and not expression.expression: + return exp.cast( + expression.this, + to=exp.DataType.Type.TIMESTAMP, + ) + return expression |