summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/annotate_types.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/annotate_types.py')
-rw-r--r--sqlglot/optimizer/annotate_types.py158
1 files changed, 137 insertions, 21 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index 3f5f089..a2cef37 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -1,16 +1,20 @@
from sqlglot import exp
from sqlglot.helper import ensure_list, subclasses
+from sqlglot.optimizer.schema import ensure_schema
+from sqlglot.optimizer.scope import Scope, traverse_scope
def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
"""
Recursively infer & annotate types in an expression syntax tree against a schema.
+ Assumes that we've already executed the optimizer's qualify_columns step.
- (TODO -- replace this with a better example after adding some functionality)
Example:
>>> import sqlglot
- >>> annotated_expression = annotate_types(sqlglot.parse_one('5 + 5.3'))
- >>> annotated_expression.type
+ >>> schema = {"y": {"cola": "SMALLINT"}}
+ >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
+ >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
+ >>> annotated_expr.expressions[0].type # Get the type of "x.cola + 2.5 AS cola"
<Type.DOUBLE: 'DOUBLE'>
Args:
@@ -22,6 +26,8 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
sqlglot.Expression: expression annotated with types
"""
+ schema = ensure_schema(schema)
+
return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
@@ -35,10 +41,81 @@ class TypeAnnotator:
expr_type: lambda self, expr: self._annotate_binary(expr)
for expr_type in subclasses(exp.__name__, exp.Binary)
},
- exp.Cast: lambda self, expr: self._annotate_cast(expr),
- exp.DataType: lambda self, expr: self._annotate_data_type(expr),
+ exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"].this),
+ exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.this),
+ exp.Alias: lambda self, expr: self._annotate_unary(expr),
exp.Literal: lambda self, expr: self._annotate_literal(expr),
- exp.Boolean: lambda self, expr: self._annotate_boolean(expr),
+ exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
+ exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
+ exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
+ exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
+ exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+ exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
+ exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
+ exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
+ exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
+ exp.CurrentTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+ exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+ exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
+ exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
+ exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
+ exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
+ exp.DatetimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
+ exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
+ exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
+ exp.TimestampAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+ exp.TimestampSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+ exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
+ exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+ exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+ exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
+ exp.DateStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
+ exp.DateToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+ exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
+ exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
+ exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
+ exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+ exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
+ exp.If: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
+ exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+ exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
+ exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
+ exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+ exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+ exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+ exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+ exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+ exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
+ exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+ exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+ exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+ exp.RegexpLike: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
+ exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+ exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+ exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+ exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
+ exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
+ exp.StrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+ exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+ exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+ exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+ exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+ exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+ exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+ exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
+ exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+ exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+ exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+ exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
+ exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
+ exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+ exp.UnixToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+ exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+ exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+ exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+ exp.VariancePop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+ exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
+ exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
}
# Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
@@ -97,43 +174,82 @@ class TypeAnnotator:
},
}
+ TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)
+
def __init__(self, schema=None, annotators=None, coerces_to=None):
self.schema = schema
self.annotators = annotators or self.ANNOTATORS
self.coerces_to = coerces_to or self.COERCES_TO
def annotate(self, expression):
+ if isinstance(expression, self.TRAVERSABLES):
+ for scope in traverse_scope(expression):
+ subscope_selects = {
+ name: {select.alias_or_name: select for select in source.selects}
+ for name, source in scope.sources.items()
+ if isinstance(source, Scope)
+ }
+
+ # First annotate the current scope's column references
+ for col in scope.columns:
+ source = scope.sources[col.table]
+ if isinstance(source, exp.Table):
+ col.type = self.schema.get_column_type(source, col)
+ else:
+ col.type = subscope_selects[col.table][col.name].type
+
+ # Then (possibly) annotate the remaining expressions in the scope
+ self._maybe_annotate(scope.expression)
+
+ return self._maybe_annotate(expression) # This takes care of non-traversable expressions
+
+ def _maybe_annotate(self, expression):
if not isinstance(expression, exp.Expression):
return None
+ if expression.type:
+ return expression # We've already inferred the expression's type
+
annotator = self.annotators.get(expression.__class__)
- return annotator(self, expression) if annotator else self._annotate_args(expression)
+ return (
+ annotator(self, expression)
+ if annotator
+ else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
+ )
def _annotate_args(self, expression):
for value in expression.args.values():
for v in ensure_list(value):
- self.annotate(v)
+ self._maybe_annotate(v)
return expression
- def _annotate_cast(self, expression):
- expression.type = expression.args["to"].this
- return self._annotate_args(expression)
-
- def _annotate_data_type(self, expression):
- expression.type = expression.this
- return self._annotate_args(expression)
-
def _maybe_coerce(self, type1, type2):
+ # We propagate the NULL / UNKNOWN types upwards if found
+ if exp.DataType.Type.NULL in (type1, type2):
+ return exp.DataType.Type.NULL
+ if exp.DataType.Type.UNKNOWN in (type1, type2):
+ return exp.DataType.Type.UNKNOWN
+
return type2 if type2 in self.coerces_to[type1] else type1
def _annotate_binary(self, expression):
self._annotate_args(expression)
- if isinstance(expression, (exp.Condition, exp.Predicate)):
+ left_type = expression.left.type
+ right_type = expression.right.type
+
+ if isinstance(expression, (exp.And, exp.Or)):
+ if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
+ expression.type = exp.DataType.Type.NULL
+ elif exp.DataType.Type.NULL in (left_type, right_type):
+ expression.type = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN"))
+ else:
+ expression.type = exp.DataType.Type.BOOLEAN
+ elif isinstance(expression, (exp.Condition, exp.Predicate)):
expression.type = exp.DataType.Type.BOOLEAN
else:
- expression.type = self._maybe_coerce(expression.left.type, expression.right.type)
+ expression.type = self._maybe_coerce(left_type, right_type)
return expression
@@ -157,6 +273,6 @@ class TypeAnnotator:
return expression
- def _annotate_boolean(self, expression):
- expression.type = exp.DataType.Type.BOOLEAN
- return expression
+ def _annotate_with_type(self, expression, target_type):
+ expression.type = target_type
+ return self._annotate_args(expression)