summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/annotate_types.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/annotate_types.py')
-rw-r--r--sqlglot/optimizer/annotate_types.py131
1 files changed, 104 insertions, 27 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index 30055bc..96331e2 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -1,5 +1,5 @@
from sqlglot import exp
-from sqlglot.helper import ensure_list, subclasses
+from sqlglot.helper import ensure_collection, ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
@@ -48,35 +48,65 @@ class TypeAnnotator:
exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
- exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
+ exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.BIGINT
+ ),
exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
- exp.CurrentTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
- exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+ exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.DATETIME
+ ),
+ exp.CurrentTime: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.TIMESTAMP
+ ),
+ exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.TIMESTAMP
+ ),
exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
- exp.DatetimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
+ exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.DATETIME
+ ),
+ exp.DatetimeSub: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.DATETIME
+ ),
exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.TimestampAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
- exp.TimestampSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+ exp.TimestampAdd: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.TIMESTAMP
+ ),
+ exp.TimestampSub: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.TIMESTAMP
+ ),
exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.DateStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.DateToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+ exp.DateStrToDate: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.DATE
+ ),
+ exp.DateToDateStr: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.VARCHAR
+ ),
exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.If: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
+ exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
+ exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
+ exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
+ exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
+ exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+ exp.GroupConcat: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.VARCHAR
+ ),
+ exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.VARCHAR
+ ),
exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
@@ -88,32 +118,52 @@ class TypeAnnotator:
exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.RegexpLike: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
+ exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.DOUBLE
+ ),
+ exp.RegexpLike: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.BOOLEAN
+ ),
exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.StrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+ exp.StrToTime: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.TIMESTAMP
+ ),
exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+ exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.VARCHAR
+ ),
+ exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.DATE
+ ),
+ exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.TIMESTAMP
+ ),
exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+ exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.VARCHAR
+ ),
exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.UnixToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
- exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+ exp.UnixToTime: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.TIMESTAMP
+ ),
+ exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.VARCHAR
+ ),
exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.VariancePop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+ exp.VariancePop: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.DOUBLE
+ ),
exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
}
@@ -124,7 +174,11 @@ class TypeAnnotator:
exp.DataType.Type.TEXT: set(),
exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
- exp.DataType.Type.NCHAR: {exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
+ exp.DataType.Type.NCHAR: {
+ exp.DataType.Type.VARCHAR,
+ exp.DataType.Type.NVARCHAR,
+ exp.DataType.Type.TEXT,
+ },
exp.DataType.Type.CHAR: {
exp.DataType.Type.NCHAR,
exp.DataType.Type.VARCHAR,
@@ -135,7 +189,11 @@ class TypeAnnotator:
exp.DataType.Type.DOUBLE: set(),
exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
- exp.DataType.Type.BIGINT: {exp.DataType.Type.DECIMAL, exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
+ exp.DataType.Type.BIGINT: {
+ exp.DataType.Type.DECIMAL,
+ exp.DataType.Type.FLOAT,
+ exp.DataType.Type.DOUBLE,
+ },
exp.DataType.Type.INT: {
exp.DataType.Type.BIGINT,
exp.DataType.Type.DECIMAL,
@@ -160,7 +218,10 @@ class TypeAnnotator:
# DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
exp.DataType.Type.TIMESTAMPLTZ: set(),
exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
- exp.DataType.Type.TIMESTAMP: {exp.DataType.Type.TIMESTAMPTZ, exp.DataType.Type.TIMESTAMPLTZ},
+ exp.DataType.Type.TIMESTAMP: {
+ exp.DataType.Type.TIMESTAMPTZ,
+ exp.DataType.Type.TIMESTAMPLTZ,
+ },
exp.DataType.Type.DATETIME: {
exp.DataType.Type.TIMESTAMP,
exp.DataType.Type.TIMESTAMPTZ,
@@ -219,7 +280,7 @@ class TypeAnnotator:
def _annotate_args(self, expression):
for value in expression.args.values():
- for v in ensure_list(value):
+ for v in ensure_collection(value):
self._maybe_annotate(v)
return expression
@@ -243,7 +304,9 @@ class TypeAnnotator:
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
expression.type = exp.DataType.Type.NULL
elif exp.DataType.Type.NULL in (left_type, right_type):
- expression.type = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN"))
+ expression.type = exp.DataType.build(
+ "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
+ )
else:
expression.type = exp.DataType.Type.BOOLEAN
elif isinstance(expression, (exp.Condition, exp.Predicate)):
@@ -276,3 +339,17 @@ class TypeAnnotator:
def _annotate_with_type(self, expression, target_type):
expression.type = target_type
return self._annotate_args(expression)
+
+ def _annotate_by_args(self, expression, *args):
+ self._annotate_args(expression)
+ expressions = []
+ for arg in args:
+ arg_expr = expression.args.get(arg)
+ expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
+
+ last_datatype = None
+ for expr in expressions:
+ last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
+
+ expression.type = last_datatype or exp.DataType.Type.UNKNOWN
+ return expression