summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/annotate_types.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/annotate_types.py')
-rw-r--r--sqlglot/optimizer/annotate_types.py37
1 files changed, 28 insertions, 9 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index 191ea52..be17f15 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -14,7 +14,7 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
>>> schema = {"y": {"cola": "SMALLINT"}}
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
- >>> annotated_expr.expressions[0].type # Get the type of "x.cola + 2.5 AS cola"
+ >>> annotated_expr.expressions[0].type.this # Get the type of "x.cola + 2.5 AS cola"
<Type.DOUBLE: 'DOUBLE'>
Args:
@@ -41,9 +41,12 @@ class TypeAnnotator:
expr_type: lambda self, expr: self._annotate_binary(expr)
for expr_type in subclasses(exp.__name__, exp.Binary)
},
- exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"].this),
- exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.this),
+ exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
+ exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
+ exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr),
exp.Alias: lambda self, expr: self._annotate_unary(expr),
+ exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
+ exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.Literal: lambda self, expr: self._annotate_literal(expr),
exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
@@ -52,6 +55,9 @@ class TypeAnnotator:
expr, exp.DataType.Type.BIGINT
),
exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+ exp.Min: lambda self, expr: self._annotate_by_args(expr, "this"),
+ exp.Max: lambda self, expr: self._annotate_by_args(expr, "this"),
+ exp.Sum: lambda self, expr: self._annotate_by_args(expr, "this", promote=True),
exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
@@ -263,10 +269,10 @@ class TypeAnnotator:
}
# First annotate the current scope's column references
for col in scope.columns:
- source = scope.sources[col.table]
+ source = scope.sources.get(col.table)
if isinstance(source, exp.Table):
col.type = self.schema.get_column_type(source, col)
- else:
+ elif source:
col.type = selects[col.table][col.name].type
# Then (possibly) annotate the remaining expressions in the scope
self._maybe_annotate(scope.expression)
@@ -280,6 +286,7 @@ class TypeAnnotator:
return expression # We've already inferred the expression's type
annotator = self.annotators.get(expression.__class__)
+
return (
annotator(self, expression)
if annotator
@@ -295,18 +302,23 @@ class TypeAnnotator:
def _maybe_coerce(self, type1, type2):
# We propagate the NULL / UNKNOWN types upwards if found
+ if isinstance(type1, exp.DataType):
+ type1 = type1.this
+ if isinstance(type2, exp.DataType):
+ type2 = type2.this
+
if exp.DataType.Type.NULL in (type1, type2):
return exp.DataType.Type.NULL
if exp.DataType.Type.UNKNOWN in (type1, type2):
return exp.DataType.Type.UNKNOWN
- return type2 if type2 in self.coerces_to[type1] else type1
+ return type2 if type2 in self.coerces_to.get(type1, {}) else type1
def _annotate_binary(self, expression):
self._annotate_args(expression)
- left_type = expression.left.type
- right_type = expression.right.type
+ left_type = expression.left.type.this
+ right_type = expression.right.type.this
if isinstance(expression, (exp.And, exp.Or)):
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
@@ -348,7 +360,7 @@ class TypeAnnotator:
expression.type = target_type
return self._annotate_args(expression)
- def _annotate_by_args(self, expression, *args):
+ def _annotate_by_args(self, expression, *args, promote=False):
self._annotate_args(expression)
expressions = []
for arg in args:
@@ -360,4 +372,11 @@ class TypeAnnotator:
last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
expression.type = last_datatype or exp.DataType.Type.UNKNOWN
+
+ if promote:
+ if expression.type.this in exp.DataType.INTEGER_TYPES:
+ expression.type = exp.DataType.Type.BIGINT
+ elif expression.type.this in exp.DataType.FLOAT_TYPES:
+ expression.type = exp.DataType.Type.DOUBLE
+
return expression