diff options
Diffstat (limited to 'sqlglot/optimizer/annotate_types.py')
-rw-r--r-- | sqlglot/optimizer/annotate_types.py | 131 |
1 files changed, 104 insertions, 27 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 30055bc..96331e2 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -1,5 +1,5 @@ from sqlglot import exp -from sqlglot.helper import ensure_list, subclasses +from sqlglot.helper import ensure_collection, ensure_list, subclasses from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.schema import ensure_schema @@ -48,35 +48,65 @@ class TypeAnnotator: exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL), exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN), - exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), + exp.ApproxDistinct: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.BIGINT + ), exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME), - exp.CurrentTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), - exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.CurrentDatetime: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.DATETIME + ), + exp.CurrentTime: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.TIMESTAMP + ), + exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.TIMESTAMP + ), exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME), - exp.DatetimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME), + exp.DatetimeAdd: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.DATETIME + ), + exp.DatetimeSub: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.DATETIME + ), exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.TimestampAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), - exp.TimestampSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.TimestampAdd: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.TIMESTAMP + ), + exp.TimestampSub: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.TIMESTAMP + ), exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.DateStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.DateToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.DateStrToDate: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.DATE + ), + exp.DateToDateStr: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.VARCHAR + ), exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.If: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), + exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"), + exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"), + exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"), + exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"), + exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.GroupConcat: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.VARCHAR + ), + exp.ArrayConcat: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.VARCHAR + ), exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), @@ -88,32 +118,52 @@ class TypeAnnotator: exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.RegexpLike: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), + exp.ApproxQuantile: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.DOUBLE + ), + exp.RegexpLike: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.BOOLEAN + ), exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.StrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.StrToTime: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.TIMESTAMP + ), exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.VARCHAR + ), + exp.TimeStrToDate: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.DATE + ), + exp.TimeStrToTime: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.TIMESTAMP + ), exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.VARCHAR + ), exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.UnixToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), - exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.UnixToTime: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.TIMESTAMP + ), + exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.VARCHAR + ), exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.VariancePop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.VariancePop: lambda self, expr: self._annotate_with_type( + expr, exp.DataType.Type.DOUBLE + ), exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), } @@ -124,7 +174,11 @@ class TypeAnnotator: exp.DataType.Type.TEXT: set(), exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT}, exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT}, - exp.DataType.Type.NCHAR: {exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT}, + exp.DataType.Type.NCHAR: { + exp.DataType.Type.VARCHAR, + exp.DataType.Type.NVARCHAR, + exp.DataType.Type.TEXT, + }, exp.DataType.Type.CHAR: { exp.DataType.Type.NCHAR, exp.DataType.Type.VARCHAR, @@ -135,7 +189,11 @@ class TypeAnnotator: exp.DataType.Type.DOUBLE: set(), exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE}, exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}, - exp.DataType.Type.BIGINT: {exp.DataType.Type.DECIMAL, exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}, + exp.DataType.Type.BIGINT: { + exp.DataType.Type.DECIMAL, + exp.DataType.Type.FLOAT, + exp.DataType.Type.DOUBLE, + }, exp.DataType.Type.INT: { exp.DataType.Type.BIGINT, exp.DataType.Type.DECIMAL, @@ -160,7 +218,10 @@ class TypeAnnotator: # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ exp.DataType.Type.TIMESTAMPLTZ: set(), exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ}, - exp.DataType.Type.TIMESTAMP: {exp.DataType.Type.TIMESTAMPTZ, exp.DataType.Type.TIMESTAMPLTZ}, + exp.DataType.Type.TIMESTAMP: { + exp.DataType.Type.TIMESTAMPTZ, + exp.DataType.Type.TIMESTAMPLTZ, + }, exp.DataType.Type.DATETIME: { exp.DataType.Type.TIMESTAMP, exp.DataType.Type.TIMESTAMPTZ, @@ -219,7 +280,7 @@ class TypeAnnotator: def _annotate_args(self, expression): for value in expression.args.values(): - for v in ensure_list(value): + for v in ensure_collection(value): self._maybe_annotate(v) return expression @@ -243,7 +304,9 @@ class TypeAnnotator: if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL: expression.type = exp.DataType.Type.NULL elif exp.DataType.Type.NULL in (left_type, right_type): - expression.type = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")) + expression.type = exp.DataType.build( + "NULLABLE", expressions=exp.DataType.build("BOOLEAN") + ) else: expression.type = exp.DataType.Type.BOOLEAN elif isinstance(expression, (exp.Condition, exp.Predicate)): @@ -276,3 +339,17 @@ class TypeAnnotator: def _annotate_with_type(self, expression, target_type): expression.type = target_type return self._annotate_args(expression) + + def _annotate_by_args(self, expression, *args): + self._annotate_args(expression) + expressions = [] + for arg in args: + arg_expr = expression.args.get(arg) + expressions.extend(expr for expr in ensure_list(arg_expr) if expr) + + last_datatype = None + for expr in expressions: + last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type) + + expression.type = last_datatype or exp.DataType.Type.UNKNOWN + return expression |