sqlglot.optimizer.annotate_types
from sqlglot import exp
from sqlglot.helper import ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema


def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
    """
    Infer the type of every node in an expression tree and store it on the node.

    The schema is normalized with ``ensure_schema`` and the work is delegated to
    ``TypeAnnotator.annotate``. Assumes that the optimizer's qualify_columns step
    has already been executed.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression (sqlglot.Expression): Expression to annotate.
        schema: Database schema, in any form accepted by ``ensure_schema``
            (the example above uses a nested dict).
        annotators (dict): Maps expression type to corresponding annotation function.
            Falls back to ``TypeAnnotator.ANNOTATORS`` when falsy.
        coerces_to (dict): Maps a data type to the set of data types that it can be
            coerced into. Falls back to ``TypeAnnotator.COERCES_TO`` when falsy.
    Returns:
        sqlglot.Expression: expression annotated with types
    """
    annotator = TypeAnnotator(ensure_schema(schema), annotators, coerces_to)
    return annotator.annotate(expression)


def _fixed_type_annotator(data_type):
    """Return an annotator that always assigns ``data_type`` to the visited node."""
    return lambda self, expr: self._annotate_with_type(expr, data_type)


def _constant_type_annotators():
    """Build the annotators of every expression class whose result type is constant."""
    kind = exp.DataType.Type
    expression_types_by_result = {
        kind.BOOLEAN: (exp.Between, exp.In, exp.Boolean, exp.RegexpLike),
        kind.NULL: (exp.Null,),
        kind.UNKNOWN: (exp.Anonymous,),
        kind.INTERVAL: (exp.Interval,),
        kind.TINYINT: (exp.Day, exp.Month, exp.Week, exp.Year),
        kind.BIGINT: (exp.ApproxDistinct, exp.Count, exp.Length),
        kind.INT: (
            exp.Ceil,
            exp.DateDiff,
            exp.DatetimeDiff,
            exp.Extract,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Floor,
            exp.Levenshtein,
            exp.StrPosition,
            exp.TsOrDiToDi,
        ),
        kind.DOUBLE: (
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Log2,
            exp.Log10,
            exp.Pow,
            exp.Quantile,
            exp.ApproxQuantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        ),
        kind.DATE: (
            exp.CurrentDate,
            exp.DateAdd,
            exp.DateSub,
            exp.DateStrToDate,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        ),
        kind.DATETIME: (exp.CurrentDatetime, exp.DatetimeAdd, exp.DatetimeSub),
        kind.TIMESTAMP: (
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.TimeAdd,
            exp.TimeSub,
            exp.StrToTime,
            exp.TimeStrToTime,
            exp.UnixToTime,
        ),
        kind.VARCHAR: (
            exp.DateToDateStr,
            exp.Concat,
            exp.ConcatWs,
            exp.GroupConcat,
            exp.ArrayConcat,
            exp.Initcap,
            exp.Lower,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        ),
    }
    return {
        expr_type: _fixed_type_annotator(data_type)
        for data_type, expr_types in expression_types_by_result.items()
        for expr_type in expr_types
    }


def _default_coercions():
    """Expand the narrowest-to-widest type chains into a ``{type: wider types}`` dict."""
    kind = exp.DataType.Type
    chains = (
        # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
        (kind.CHAR, kind.NCHAR, kind.VARCHAR, kind.NVARCHAR, kind.TEXT),
        # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
        (
            kind.TINYINT,
            kind.SMALLINT,
            kind.INT,
            kind.BIGINT,
            kind.DECIMAL,
            kind.FLOAT,
            kind.DOUBLE,
        ),
        # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
        (kind.DATE, kind.DATETIME, kind.TIMESTAMP, kind.TIMESTAMPTZ, kind.TIMESTAMPLTZ),
    )
    # Every type coerces to all types that follow it in its chain; the widest
    # member of a chain therefore maps to an empty set.
    return {
        narrow: set(chain[rank:])
        for chain in chains
        for rank, narrow in enumerate(chain, start=1)
    }


class TypeAnnotator:
    """
    Walks an expression tree and assigns a type to each node.

    ``ANNOTATORS`` maps an expression class to a callable taking
    ``(annotator, expression)``; ``COERCES_TO`` maps a data type to the set of
    data types it can be coerced into. Both can be replaced per instance.
    """

    ANNOTATORS = {
        **dict.fromkeys(
            subclasses(exp.__name__, exp.Unary),
            lambda self, expr: self._annotate_unary(expr),
        ),
        **dict.fromkeys(
            subclasses(exp.__name__, exp.Binary),
            lambda self, expr: self._annotate_binary(expr),
        ),
        # Entries below intentionally come after the Unary / Binary defaults so
        # that they take precedence for any class listed in both places.
        **_constant_type_annotators(),
        exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
        exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
        exp.Alias: lambda self, expr: self._annotate_unary(expr),
        exp.Literal: lambda self, expr: self._annotate_literal(expr),
        # The type of these expressions is derived from (some of) their arguments
        exp.Min: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
        exp.Max: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
        exp.Sum: lambda self, expr: self._annotate_by_args(
            expr, "this", "expressions", promote=True
        ),
        exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
        exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
        exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
        exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
        exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
    }

    # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
    COERCES_TO = _default_coercions()

    TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)

    def __init__(self, schema=None, annotators=None, coerces_to=None):
        """Store the schema and fall back to the class-level tables for falsy overrides."""
        self.schema = schema
        self.annotators = annotators or self.ANNOTATORS
        self.coerces_to = coerces_to or self.COERCES_TO

    def annotate(self, expression):
        """Annotate ``expression`` (scope by scope when it is traversable) and return it."""
        if isinstance(expression, self.TRAVERSABLES):
            for scope in traverse_scope(expression):
                selects = self._source_selects(scope)

                # First annotate the current scope's column references
                for column in scope.columns:
                    table = column.table
                    if not table:
                        continue

                    source = scope.sources.get(table)
                    if isinstance(source, exp.Table):
                        column.type = self.schema.get_column_type(source, column)
                    elif source and column.name in selects.get(table, ()):
                        column.type = selects[table][column.name].type

                # Then (possibly) annotate the remaining expressions in the scope
                self._maybe_annotate(scope.expression)

        # This takes care of non-traversable expressions
        return self._maybe_annotate(expression)

    def _source_selects(self, scope):
        """Map each Scope-backed source name to a ``{column name: expression}`` dict."""
        selects = {}
        for name, source in scope.sources.items():
            if not isinstance(source, Scope):
                continue

            node = source.expression
            if not isinstance(node, exp.UDTF):
                selects[name] = {select.alias_or_name: select for select in node.selects}
                continue

            if not isinstance(node, exp.Lateral):
                values = node.expressions[0].expressions
            elif isinstance(node.this, exp.Explode):
                values = [node.this.this]
            else:
                values = []

            # Sources that expose no values contribute no columns
            if values:
                selects[name] = dict(zip(node.alias_column_names, values))

        return selects

    def _maybe_annotate(self, expression):
        """Annotate ``expression`` unless it already carries a type."""
        if expression.type:
            return expression  # We've already inferred the expression's type

        annotator = self.annotators.get(expression.__class__)
        if not annotator:
            # Expression classes without a registered annotator are typed UNKNOWN
            return self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)

        return annotator(self, expression)

    def _annotate_args(self, expression):
        """Annotate every child yielded by ``iter_expressions`` and return ``expression``."""
        for _key, child in expression.iter_expressions():
            self._maybe_annotate(child)

        return expression

    def _maybe_coerce(self, type1, type2):
        """Return ``type2`` if ``type1`` coerces into it, else ``type1``."""
        kind = exp.DataType.Type
        left = type1.this if isinstance(type1, exp.DataType) else type1
        right = type2.this if isinstance(type2, exp.DataType) else type2

        # We propagate the NULL / UNKNOWN types upwards if found
        for contagious in (kind.NULL, kind.UNKNOWN):
            if contagious in (left, right):
                return contagious

        if right in self.coerces_to.get(left, {}):
            return right
        return left

    def _annotate_binary(self, expression):
        """Annotate both operands, then type the binary expression itself."""
        self._annotate_args(expression)

        kind = exp.DataType.Type
        left_type = expression.left.type.this
        right_type = expression.right.type.this

        if isinstance(expression, (exp.And, exp.Or)):
            null_operands = [operand == kind.NULL for operand in (left_type, right_type)]
            if all(null_operands):
                result = kind.NULL
            elif any(null_operands):
                result = exp.DataType.build(
                    "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
                )
            else:
                result = kind.BOOLEAN
        elif isinstance(expression, (exp.Condition, exp.Predicate)):
            result = kind.BOOLEAN
        else:
            result = self._maybe_coerce(left_type, right_type)

        expression.type = result
        return expression

    def _annotate_unary(self, expression):
        """Type a unary expression as BOOLEAN for non-Paren conditions, else like its operand."""
        self._annotate_args(expression)

        is_condition = isinstance(expression, exp.Condition) and not isinstance(
            expression, exp.Paren
        )
        expression.type = exp.DataType.Type.BOOLEAN if is_condition else expression.this.type
        return expression

    def _annotate_literal(self, expression):
        """Type a literal as VARCHAR (string), INT (integer) or DOUBLE (anything else)."""
        kind = exp.DataType.Type
        if expression.is_string:
            literal_type = kind.VARCHAR
        elif expression.is_int:
            literal_type = kind.INT
        else:
            literal_type = kind.DOUBLE

        expression.type = literal_type
        return expression

    def _annotate_with_type(self, expression, target_type):
        """Assign ``target_type`` to ``expression`` and annotate its children."""
        expression.type = target_type
        return self._annotate_args(expression)

    def _annotate_by_args(self, expression, *args, promote=False):
        """
        Derive the type of ``expression`` from the arguments named in ``args``.

        The argument types are folded pairwise with ``_maybe_coerce``. With
        ``promote=True`` an integer result becomes BIGINT and a float result DOUBLE.
        """
        self._annotate_args(expression)

        candidates = [
            candidate
            for arg_name in args
            for candidate in ensure_list(expression.args.get(arg_name))
            if candidate
        ]

        merged = None
        for candidate in candidates:
            merged = self._maybe_coerce(merged or candidate.type, candidate.type)

        expression.type = merged or exp.DataType.Type.UNKNOWN

        if promote:
            current = expression.type.this
            if current in exp.DataType.INTEGER_TYPES:
                expression.type = exp.DataType.Type.BIGINT
            elif current in exp.DataType.FLOAT_TYPES:
                expression.type = exp.DataType.Type.DOUBLE

        return expression
def
annotate_types(expression, schema=None, annotators=None, coerces_to=None):
def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
    """
    Infer the type of every node in an expression tree and store it on the node.

    The schema is normalized with ``ensure_schema`` and the work is delegated to
    ``TypeAnnotator.annotate``. Assumes that the optimizer's qualify_columns step
    has already been executed.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression (sqlglot.Expression): Expression to annotate.
        schema: Database schema, in any form accepted by ``ensure_schema``
            (the example above uses a nested dict).
        annotators (dict): Maps expression type to corresponding annotation function.
        coerces_to (dict): Maps a data type to the set of data types that it can be
            coerced into.
    Returns:
        sqlglot.Expression: expression annotated with types
    """
    annotator = TypeAnnotator(ensure_schema(schema), annotators, coerces_to)
    return annotator.annotate(expression)
Recursively infer & annotate types in an expression syntax tree against a schema. Assumes that we've already executed the optimizer's qualify_columns step.
Example:
>>> import sqlglot
>>> schema = {"y": {"cola": "SMALLINT"}}
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
>>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
<Type.DOUBLE: 'DOUBLE'>
Arguments:
- expression (sqlglot.Expression): Expression to annotate.
- schema (dict|sqlglot.optimizer.Schema): Database schema.
- annotators (dict): Maps expression type to corresponding annotation function.
- coerces_to (dict): Maps a data type to the set of data types that it can be coerced into.
Returns:
sqlglot.Expression: expression annotated with types
class
TypeAnnotator:
def _fixed_type_annotator(data_type):
    """Return an annotator that always assigns ``data_type`` to the visited node."""
    return lambda self, expr: self._annotate_with_type(expr, data_type)


def _constant_type_annotators():
    """Build the annotators of every expression class whose result type is constant."""
    kind = exp.DataType.Type
    expression_types_by_result = {
        kind.BOOLEAN: (exp.Between, exp.In, exp.Boolean, exp.RegexpLike),
        kind.NULL: (exp.Null,),
        kind.UNKNOWN: (exp.Anonymous,),
        kind.INTERVAL: (exp.Interval,),
        kind.TINYINT: (exp.Day, exp.Month, exp.Week, exp.Year),
        kind.BIGINT: (exp.ApproxDistinct, exp.Count, exp.Length),
        kind.INT: (
            exp.Ceil,
            exp.DateDiff,
            exp.DatetimeDiff,
            exp.Extract,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Floor,
            exp.Levenshtein,
            exp.StrPosition,
            exp.TsOrDiToDi,
        ),
        kind.DOUBLE: (
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Log2,
            exp.Log10,
            exp.Pow,
            exp.Quantile,
            exp.ApproxQuantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        ),
        kind.DATE: (
            exp.CurrentDate,
            exp.DateAdd,
            exp.DateSub,
            exp.DateStrToDate,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        ),
        kind.DATETIME: (exp.CurrentDatetime, exp.DatetimeAdd, exp.DatetimeSub),
        kind.TIMESTAMP: (
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.TimeAdd,
            exp.TimeSub,
            exp.StrToTime,
            exp.TimeStrToTime,
            exp.UnixToTime,
        ),
        kind.VARCHAR: (
            exp.DateToDateStr,
            exp.Concat,
            exp.ConcatWs,
            exp.GroupConcat,
            exp.ArrayConcat,
            exp.Initcap,
            exp.Lower,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        ),
    }
    return {
        expr_type: _fixed_type_annotator(data_type)
        for data_type, expr_types in expression_types_by_result.items()
        for expr_type in expr_types
    }


def _default_coercions():
    """Expand the narrowest-to-widest type chains into a ``{type: wider types}`` dict."""
    kind = exp.DataType.Type
    chains = (
        # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
        (kind.CHAR, kind.NCHAR, kind.VARCHAR, kind.NVARCHAR, kind.TEXT),
        # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
        (
            kind.TINYINT,
            kind.SMALLINT,
            kind.INT,
            kind.BIGINT,
            kind.DECIMAL,
            kind.FLOAT,
            kind.DOUBLE,
        ),
        # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
        (kind.DATE, kind.DATETIME, kind.TIMESTAMP, kind.TIMESTAMPTZ, kind.TIMESTAMPLTZ),
    )
    # Every type coerces to all types that follow it in its chain; the widest
    # member of a chain therefore maps to an empty set.
    return {
        narrow: set(chain[rank:])
        for chain in chains
        for rank, narrow in enumerate(chain, start=1)
    }


class TypeAnnotator:
    """
    Walks an expression tree and assigns a type to each node.

    ``ANNOTATORS`` maps an expression class to a callable taking
    ``(annotator, expression)``; ``COERCES_TO`` maps a data type to the set of
    data types it can be coerced into. Both can be replaced per instance.
    """

    ANNOTATORS = {
        **dict.fromkeys(
            subclasses(exp.__name__, exp.Unary),
            lambda self, expr: self._annotate_unary(expr),
        ),
        **dict.fromkeys(
            subclasses(exp.__name__, exp.Binary),
            lambda self, expr: self._annotate_binary(expr),
        ),
        # Entries below intentionally come after the Unary / Binary defaults so
        # that they take precedence for any class listed in both places.
        **_constant_type_annotators(),
        exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
        exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
        exp.Alias: lambda self, expr: self._annotate_unary(expr),
        exp.Literal: lambda self, expr: self._annotate_literal(expr),
        # The type of these expressions is derived from (some of) their arguments
        exp.Min: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
        exp.Max: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
        exp.Sum: lambda self, expr: self._annotate_by_args(
            expr, "this", "expressions", promote=True
        ),
        exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
        exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
        exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
        exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
        exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
    }

    # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
    COERCES_TO = _default_coercions()

    TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)

    def __init__(self, schema=None, annotators=None, coerces_to=None):
        """Store the schema and fall back to the class-level tables for falsy overrides."""
        self.schema = schema
        self.annotators = annotators or self.ANNOTATORS
        self.coerces_to = coerces_to or self.COERCES_TO

    def annotate(self, expression):
        """Annotate ``expression`` (scope by scope when it is traversable) and return it."""
        if isinstance(expression, self.TRAVERSABLES):
            for scope in traverse_scope(expression):
                selects = self._source_selects(scope)

                # First annotate the current scope's column references
                for column in scope.columns:
                    table = column.table
                    if not table:
                        continue

                    source = scope.sources.get(table)
                    if isinstance(source, exp.Table):
                        column.type = self.schema.get_column_type(source, column)
                    elif source and column.name in selects.get(table, ()):
                        column.type = selects[table][column.name].type

                # Then (possibly) annotate the remaining expressions in the scope
                self._maybe_annotate(scope.expression)

        # This takes care of non-traversable expressions
        return self._maybe_annotate(expression)

    def _source_selects(self, scope):
        """Map each Scope-backed source name to a ``{column name: expression}`` dict."""
        selects = {}
        for name, source in scope.sources.items():
            if not isinstance(source, Scope):
                continue

            node = source.expression
            if not isinstance(node, exp.UDTF):
                selects[name] = {select.alias_or_name: select for select in node.selects}
                continue

            if not isinstance(node, exp.Lateral):
                values = node.expressions[0].expressions
            elif isinstance(node.this, exp.Explode):
                values = [node.this.this]
            else:
                values = []

            # Sources that expose no values contribute no columns
            if values:
                selects[name] = dict(zip(node.alias_column_names, values))

        return selects

    def _maybe_annotate(self, expression):
        """Annotate ``expression`` unless it already carries a type."""
        if expression.type:
            return expression  # We've already inferred the expression's type

        annotator = self.annotators.get(expression.__class__)
        if not annotator:
            # Expression classes without a registered annotator are typed UNKNOWN
            return self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)

        return annotator(self, expression)

    def _annotate_args(self, expression):
        """Annotate every child yielded by ``iter_expressions`` and return ``expression``."""
        for _key, child in expression.iter_expressions():
            self._maybe_annotate(child)

        return expression

    def _maybe_coerce(self, type1, type2):
        """Return ``type2`` if ``type1`` coerces into it, else ``type1``."""
        kind = exp.DataType.Type
        left = type1.this if isinstance(type1, exp.DataType) else type1
        right = type2.this if isinstance(type2, exp.DataType) else type2

        # We propagate the NULL / UNKNOWN types upwards if found
        for contagious in (kind.NULL, kind.UNKNOWN):
            if contagious in (left, right):
                return contagious

        if right in self.coerces_to.get(left, {}):
            return right
        return left

    def _annotate_binary(self, expression):
        """Annotate both operands, then type the binary expression itself."""
        self._annotate_args(expression)

        kind = exp.DataType.Type
        left_type = expression.left.type.this
        right_type = expression.right.type.this

        if isinstance(expression, (exp.And, exp.Or)):
            null_operands = [operand == kind.NULL for operand in (left_type, right_type)]
            if all(null_operands):
                result = kind.NULL
            elif any(null_operands):
                result = exp.DataType.build(
                    "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
                )
            else:
                result = kind.BOOLEAN
        elif isinstance(expression, (exp.Condition, exp.Predicate)):
            result = kind.BOOLEAN
        else:
            result = self._maybe_coerce(left_type, right_type)

        expression.type = result
        return expression

    def _annotate_unary(self, expression):
        """Type a unary expression as BOOLEAN for non-Paren conditions, else like its operand."""
        self._annotate_args(expression)

        is_condition = isinstance(expression, exp.Condition) and not isinstance(
            expression, exp.Paren
        )
        expression.type = exp.DataType.Type.BOOLEAN if is_condition else expression.this.type
        return expression

    def _annotate_literal(self, expression):
        """Type a literal as VARCHAR (string), INT (integer) or DOUBLE (anything else)."""
        kind = exp.DataType.Type
        if expression.is_string:
            literal_type = kind.VARCHAR
        elif expression.is_int:
            literal_type = kind.INT
        else:
            literal_type = kind.DOUBLE

        expression.type = literal_type
        return expression

    def _annotate_with_type(self, expression, target_type):
        """Assign ``target_type`` to ``expression`` and annotate its children."""
        expression.type = target_type
        return self._annotate_args(expression)

    def _annotate_by_args(self, expression, *args, promote=False):
        """
        Derive the type of ``expression`` from the arguments named in ``args``.

        The argument types are folded pairwise with ``_maybe_coerce``. With
        ``promote=True`` an integer result becomes BIGINT and a float result DOUBLE.
        """
        self._annotate_args(expression)

        candidates = [
            candidate
            for arg_name in args
            for candidate in ensure_list(expression.args.get(arg_name))
            if candidate
        ]

        merged = None
        for candidate in candidates:
            merged = self._maybe_coerce(merged or candidate.type, candidate.type)

        expression.type = merged or exp.DataType.Type.UNKNOWN

        if promote:
            current = expression.type.this
            if current in exp.DataType.INTEGER_TYPES:
                expression.type = exp.DataType.Type.BIGINT
            elif current in exp.DataType.FLOAT_TYPES:
                expression.type = exp.DataType.Type.DOUBLE

        return expression
def annotate(self, expression):
    """Annotate `expression`, resolving column types scope by scope first.

    When `expression` is an instance of one of `self.TRAVERSABLES`, every
    scope yielded by `traverse_scope` is processed: table-qualified columns
    get their type from the schema (table sources) or from the matching
    projection of a child scope, and then the scope's own expression is
    annotated.

    Args:
        expression: Expression to annotate.

    Returns:
        The result of `self._maybe_annotate(expression)`.
    """
    if isinstance(expression, self.TRAVERSABLES):
        for scope in traverse_scope(expression):
            # Per source name: {output column name -> expression producing it}
            selects = {}
            for name, source in scope.sources.items():
                # Only child scopes contribute projections; other sources
                # (e.g. tables) are resolved through the schema below.
                if not isinstance(source, Scope):
                    continue

                source_expr = source.expression
                if not isinstance(source_expr, exp.UDTF):
                    selects[name] = {
                        select.alias_or_name: select for select in source_expr.selects
                    }
                    continue

                if isinstance(source_expr, exp.Lateral):
                    # Only an exploded lateral provides a value; any other
                    # lateral leaves `values` empty and is skipped below.
                    if isinstance(source_expr.this, exp.Explode):
                        values = [source_expr.this.this]
                    else:
                        values = []
                else:
                    values = source_expr.expressions[0].expressions

                if not values:
                    continue

                # Pair the alias column names with the values positionally
                selects[name] = dict(zip(source_expr.alias_column_names, values))

            # First annotate the current scope's column references
            for col in scope.columns:
                table_name = col.table
                if not table_name:
                    continue

                source = scope.sources.get(table_name)
                if isinstance(source, exp.Table):
                    col.type = self.schema.get_column_type(source, col)
                elif source and col.name in selects.get(table_name, {}):
                    col.type = selects[table_name][col.name].type

            # Then (possibly) annotate the remaining expressions in the scope
            self._maybe_annotate(scope.expression)

    # This takes care of non-traversable expressions
    return self._maybe_annotate(expression)