sqlglot.optimizer.annotate_types
from sqlglot import exp
from sqlglot.helper import ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema


def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
    """
    Recursively infer & annotate types in an expression syntax tree against a schema.
    Assumes that we've already executed the optimizer's qualify_columns step.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression (sqlglot.Expression): Expression to annotate.
        schema (dict|Schema): Database schema: a mapping, or any object accepted by
            `sqlglot.schema.ensure_schema`.
        annotators (dict): Maps expression classes to annotation functions called as
            `annotator(type_annotator, expression)`. When omitted (or empty),
            `TypeAnnotator.ANNOTATORS` is used.
        coerces_to (dict): Maps a data type to the set of data types it can be coerced
            into. When omitted (or empty), `TypeAnnotator.COERCES_TO` is used.

    Returns:
        sqlglot.Expression: expression annotated with types
    """

    schema = ensure_schema(schema)

    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)


def _annotate_with_type_lambda(data_type):
    """Build an annotator that assigns the fixed type `data_type` to an expression.

    A factory function is used so that each returned closure gets its own
    `data_type` binding. A lambda created directly inside a comprehension would
    share the comprehension's loop variable.
    """
    return lambda self, expr: self._annotate_with_type(expr, data_type)


class TypeAnnotator:
    """Infers and attaches data types to the nodes of a sqlglot expression tree.

    The annotation rule for a node is looked up by the node's exact class in
    `self.annotators`. Nodes without a rule are typed UNKNOWN. When two types have
    to be combined (binary operators, COALESCE, CASE, ...), `self.coerces_to`
    decides which of them wins.
    """

    # Expressions whose result type is fixed, regardless of their arguments.
    TYPE_TO_EXPRESSIONS = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Count,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.DateAdd,
            exp.DateStrToDate,
            exp.DateSub,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Log2,
            exp.Log10,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DateDiff,
            exp.DatetimeDiff,
            exp.DateToDi,
            exp.Extract,
            exp.Floor,
            exp.Levenshtein,
            exp.StrPosition,
            exp.TimeDiff,
            exp.TimestampDiff,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.INTERVAL: {
            exp.Interval,
        },
        exp.DataType.Type.MAP: {
            exp.Map,
            exp.VarMap,
        },
        exp.DataType.Type.NULL: {
            exp.Null,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
        },
        exp.DataType.Type.UNKNOWN: {
            exp.Anonymous,
        },
        exp.DataType.Type.VARCHAR: {
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    # Later entries override earlier ones. The fixed-type and explicit rules
    # therefore take precedence over the generic Unary/Binary rules (e.g. Pow).
    ANNOTATORS = {
        **{
            expr_type: lambda self, expr: self._annotate_unary(expr)
            for expr_type in subclasses(exp.__name__, exp.Unary)
        },
        **{
            expr_type: lambda self, expr: self._annotate_binary(expr)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
        exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
        exp.Alias: lambda self, expr: self._annotate_unary(expr),
        exp.Literal: lambda self, expr: self._annotate_literal(expr),
        exp.Min: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
        exp.Max: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
        exp.Sum: lambda self, expr: self._annotate_by_args(
            expr, "this", "expressions", promote=True
        ),
        exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
        exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
        exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
        exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
        exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
        # Concatenating arrays produces an array (previously mistyped as VARCHAR),
        # so the type is derived from the operands.
        exp.ArrayConcat: lambda self, expr: self._annotate_by_args(
            expr, "this", "expressions"
        ),
    }

    # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
    COERCES_TO = {
        # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
        exp.DataType.Type.TEXT: set(),
        exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
        exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
        exp.DataType.Type.NCHAR: {
            exp.DataType.Type.VARCHAR,
            exp.DataType.Type.NVARCHAR,
            exp.DataType.Type.TEXT,
        },
        exp.DataType.Type.CHAR: {
            exp.DataType.Type.NCHAR,
            exp.DataType.Type.VARCHAR,
            exp.DataType.Type.NVARCHAR,
            exp.DataType.Type.TEXT,
        },
        # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
        exp.DataType.Type.DOUBLE: set(),
        exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
        exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
        exp.DataType.Type.BIGINT: {
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        },
        exp.DataType.Type.INT: {
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        },
        exp.DataType.Type.SMALLINT: {
            exp.DataType.Type.INT,
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        },
        exp.DataType.Type.TINYINT: {
            exp.DataType.Type.SMALLINT,
            exp.DataType.Type.INT,
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        },
        # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
        exp.DataType.Type.TIMESTAMPLTZ: set(),
        exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
        exp.DataType.Type.TIMESTAMP: {
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMPLTZ,
        },
        exp.DataType.Type.DATETIME: {
            exp.DataType.Type.TIMESTAMP,
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMPLTZ,
        },
        exp.DataType.Type.DATE: {
            exp.DataType.Type.DATETIME,
            exp.DataType.Type.TIMESTAMP,
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMPLTZ,
        },
    }

    # Root expression kinds that are annotated scope by scope via traverse_scope.
    TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)

    def __init__(self, schema=None, annotators=None, coerces_to=None):
        """Store the schema and the annotation/coercion rules.

        Falsy `annotators` / `coerces_to` (None or empty) fall back to the class
        defaults `ANNOTATORS` / `COERCES_TO`.
        """
        self.schema = schema
        self.annotators = annotators or self.ANNOTATORS
        self.coerces_to = coerces_to or self.COERCES_TO

    def annotate(self, expression):
        """Annotate `expression` and the nodes below it with types, and return it.

        For traversable roots, each scope from `traverse_scope` is processed as follows.
        First, the scope's table-qualified column references are typed: from the
        schema for table sources, or by copying the `.type` of the matching
        projection of a derived source. Then the rest of the scope is annotated.
        Unqualified columns are skipped.
        NOTE(review): copying a projection's `.type` presumably relies on inner scopes
        being yielded before outer ones; confirm against `traverse_scope`.
        """
        if isinstance(expression, self.TRAVERSABLES):
            for scope in traverse_scope(expression):
                # Maps each derived source name to {output column name: expression}.
                selects = {}
                for name, source in scope.sources.items():
                    if not isinstance(source, Scope):
                        continue
                    if isinstance(source.expression, exp.UDTF):
                        values = []

                        if isinstance(source.expression, exp.Lateral):
                            # LATERAL EXPLODE(x): x supplies the column value.
                            if isinstance(source.expression.this, exp.Explode):
                                values = [source.expression.this.this]
                        elif source.expression.expressions:
                            # Other UDTFs take their values from their first argument. The
                            # guard avoids an IndexError on argument-less calls.
                            values = source.expression.expressions[0].expressions

                        if not values:
                            continue

                        selects[name] = dict(zip(source.expression.alias_column_names, values))
                    else:
                        selects[name] = {
                            select.alias_or_name: select for select in source.expression.selects
                        }
                # First annotate the current scope's column references
                for col in scope.columns:
                    if not col.table:
                        continue

                    source = scope.sources.get(col.table)
                    if isinstance(source, exp.Table):
                        col.type = self.schema.get_column_type(source, col)
                    elif source and col.table in selects and col.name in selects[col.table]:
                        col.type = selects[col.table][col.name].type
                # Then (possibly) annotate the remaining expressions in the scope
                self._maybe_annotate(scope.expression)
        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions

    def _maybe_annotate(self, expression):
        """Annotate `expression` unless it already has a type, and return it.

        Nodes without a registered annotator are typed UNKNOWN. Their children are
        still annotated.
        """
        if expression.type:
            return expression  # We've already inferred the expression's type

        annotator = self.annotators.get(expression.__class__)

        return (
            annotator(self, expression)
            if annotator
            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
        )

    def _annotate_args(self, expression):
        """Annotate every expression yielded by `expression.iter_expressions()`; return `expression`."""
        for _, value in expression.iter_expressions():
            self._maybe_annotate(value)

        return expression

    def _maybe_coerce(self, type1, type2):
        """Combine two types (DataType or DataType.Type) and return a DataType.Type.

        NULL takes precedence over everything, then UNKNOWN. Otherwise `type2` is
        returned if `type1` can be coerced into it, else `type1`.
        """
        # We propagate the NULL / UNKNOWN types upwards if found
        if isinstance(type1, exp.DataType):
            type1 = type1.this
        if isinstance(type2, exp.DataType):
            type2 = type2.this

        if exp.DataType.Type.NULL in (type1, type2):
            return exp.DataType.Type.NULL
        if exp.DataType.Type.UNKNOWN in (type1, type2):
            return exp.DataType.Type.UNKNOWN

        return type2 if type2 in self.coerces_to.get(type1, set()) else type1

    def _annotate_binary(self, expression):
        """Type a binary node from its (annotated) operands.

        AND/OR: NULL if both sides are NULL, NULLABLE(BOOLEAN) if exactly one is,
        otherwise BOOLEAN. Other conditions/predicates: BOOLEAN. Anything else: the
        coerced type of the two operands.
        """
        self._annotate_args(expression)

        left_type = expression.left.type.this
        right_type = expression.right.type.this

        if isinstance(expression, (exp.And, exp.Or)):
            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
                expression.type = exp.DataType.Type.NULL
            elif exp.DataType.Type.NULL in (left_type, right_type):
                expression.type = exp.DataType.build(
                    "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
                )
            else:
                expression.type = exp.DataType.Type.BOOLEAN
        elif isinstance(expression, (exp.Condition, exp.Predicate)):
            expression.type = exp.DataType.Type.BOOLEAN
        else:
            expression.type = self._maybe_coerce(left_type, right_type)

        return expression

    def _annotate_unary(self, expression):
        """Conditions (except parentheses) are BOOLEAN; otherwise inherit the operand's type."""
        self._annotate_args(expression)

        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
            expression.type = exp.DataType.Type.BOOLEAN
        else:
            expression.type = expression.this.type

        return expression

    def _annotate_literal(self, expression):
        """String literals are VARCHAR, integer literals INT, and any other literal DOUBLE."""
        if expression.is_string:
            expression.type = exp.DataType.Type.VARCHAR
        elif expression.is_int:
            expression.type = exp.DataType.Type.INT
        else:
            expression.type = exp.DataType.Type.DOUBLE

        return expression

    def _annotate_with_type(self, expression, target_type):
        """Set `expression.type` to `target_type`, then annotate its children."""
        expression.type = target_type
        return self._annotate_args(expression)

    def _annotate_by_args(self, expression, *args, promote=False):
        """Type `expression` by coercing together the types of the named args.

        Each named arg may hold a single expression or a list of them; missing or
        empty values are skipped. If nothing is found, the type is UNKNOWN. With
        `promote=True`, integer results become BIGINT and float results become
        DOUBLE.
        """
        self._annotate_args(expression)
        expressions = []
        for arg in args:
            arg_expr = expression.args.get(arg)
            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)

        last_datatype = None
        for expr in expressions:
            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)

        expression.type = last_datatype or exp.DataType.Type.UNKNOWN

        if promote:
            if expression.type.this in exp.DataType.INTEGER_TYPES:
                expression.type = exp.DataType.Type.BIGINT
            elif expression.type.this in exp.DataType.FLOAT_TYPES:
                expression.type = exp.DataType.Type.DOUBLE

        return expression
def
annotate_types(expression, schema=None, annotators=None, coerces_to=None):
def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
    """
    Recursively infer & annotate types in an expression syntax tree against a schema.
    Assumes that we've already executed the optimizer's qualify_columns step.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression (sqlglot.Expression): Expression to annotate.
        schema (dict|Schema): Database schema: a mapping, or any object accepted by
            `sqlglot.schema.ensure_schema`.
        annotators (dict): Maps expression classes to annotation functions called as
            `annotator(type_annotator, expression)`. When omitted (or empty),
            `TypeAnnotator.ANNOTATORS` is used.
        coerces_to (dict): Maps a data type to the set of data types it can be coerced
            into. When omitted (or empty), `TypeAnnotator.COERCES_TO` is used.

    Returns:
        sqlglot.Expression: expression annotated with types
    """

    # Normalize whatever was passed (dict or schema object) into a schema object.
    schema = ensure_schema(schema)

    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
Recursively infer & annotate types in an expression syntax tree against a schema. Assumes that we've already executed the optimizer's qualify_columns step.
Example:
>>> import sqlglot
>>> schema = {"y": {"cola": "SMALLINT"}}
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
>>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
<Type.DOUBLE: 'DOUBLE'>
Arguments:
- expression (sqlglot.Expression): Expression to annotate.
- schema (dict|Schema): Database schema: a mapping, or any object accepted by `sqlglot.schema.ensure_schema`.
- annotators (dict): Maps expression type to corresponding annotation function.
- coerces_to (dict): Maps a data type to the set of data types it can be coerced into.
Returns:
sqlglot.Expression: expression annotated with types
class
TypeAnnotator:
def _annotate_with_type_lambda(data_type):
    """Build an annotator that assigns the fixed type `data_type` to an expression.

    A factory function is used so that each returned closure gets its own
    `data_type` binding. A lambda created directly inside a comprehension would
    share the comprehension's loop variable.
    """
    return lambda self, expr: self._annotate_with_type(expr, data_type)


class TypeAnnotator:
    """Infers and attaches data types to the nodes of a sqlglot expression tree.

    The annotation rule for a node is looked up by the node's exact class in
    `self.annotators`. Nodes without a rule are typed UNKNOWN. When two types have
    to be combined (binary operators, COALESCE, CASE, ...), `self.coerces_to`
    decides which of them wins.
    """

    # Expressions whose result type is fixed, regardless of their arguments.
    TYPE_TO_EXPRESSIONS = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Count,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.DateAdd,
            exp.DateStrToDate,
            exp.DateSub,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Log2,
            exp.Log10,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DateDiff,
            exp.DatetimeDiff,
            exp.DateToDi,
            exp.Extract,
            exp.Floor,
            exp.Levenshtein,
            exp.StrPosition,
            exp.TimeDiff,
            exp.TimestampDiff,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.INTERVAL: {
            exp.Interval,
        },
        exp.DataType.Type.MAP: {
            exp.Map,
            exp.VarMap,
        },
        exp.DataType.Type.NULL: {
            exp.Null,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
        },
        exp.DataType.Type.UNKNOWN: {
            exp.Anonymous,
        },
        exp.DataType.Type.VARCHAR: {
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    # Later entries override earlier ones. The fixed-type and explicit rules
    # therefore take precedence over the generic Unary/Binary rules (e.g. Pow).
    ANNOTATORS = {
        **{
            expr_type: lambda self, expr: self._annotate_unary(expr)
            for expr_type in subclasses(exp.__name__, exp.Unary)
        },
        **{
            expr_type: lambda self, expr: self._annotate_binary(expr)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
        exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
        exp.Alias: lambda self, expr: self._annotate_unary(expr),
        exp.Literal: lambda self, expr: self._annotate_literal(expr),
        exp.Min: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
        exp.Max: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
        exp.Sum: lambda self, expr: self._annotate_by_args(
            expr, "this", "expressions", promote=True
        ),
        exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
        exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
        exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
        exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
        exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
        # Concatenating arrays produces an array (previously mistyped as VARCHAR),
        # so the type is derived from the operands.
        exp.ArrayConcat: lambda self, expr: self._annotate_by_args(
            expr, "this", "expressions"
        ),
    }

    # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
    COERCES_TO = {
        # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
        exp.DataType.Type.TEXT: set(),
        exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
        exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
        exp.DataType.Type.NCHAR: {
            exp.DataType.Type.VARCHAR,
            exp.DataType.Type.NVARCHAR,
            exp.DataType.Type.TEXT,
        },
        exp.DataType.Type.CHAR: {
            exp.DataType.Type.NCHAR,
            exp.DataType.Type.VARCHAR,
            exp.DataType.Type.NVARCHAR,
            exp.DataType.Type.TEXT,
        },
        # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
        exp.DataType.Type.DOUBLE: set(),
        exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
        exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
        exp.DataType.Type.BIGINT: {
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        },
        exp.DataType.Type.INT: {
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        },
        exp.DataType.Type.SMALLINT: {
            exp.DataType.Type.INT,
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        },
        exp.DataType.Type.TINYINT: {
            exp.DataType.Type.SMALLINT,
            exp.DataType.Type.INT,
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        },
        # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
        exp.DataType.Type.TIMESTAMPLTZ: set(),
        exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
        exp.DataType.Type.TIMESTAMP: {
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMPLTZ,
        },
        exp.DataType.Type.DATETIME: {
            exp.DataType.Type.TIMESTAMP,
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMPLTZ,
        },
        exp.DataType.Type.DATE: {
            exp.DataType.Type.DATETIME,
            exp.DataType.Type.TIMESTAMP,
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMPLTZ,
        },
    }

    # Root expression kinds that are annotated scope by scope via traverse_scope.
    TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)

    def __init__(self, schema=None, annotators=None, coerces_to=None):
        """Store the schema and the annotation/coercion rules.

        Falsy `annotators` / `coerces_to` (None or empty) fall back to the class
        defaults `ANNOTATORS` / `COERCES_TO`.
        """
        self.schema = schema
        self.annotators = annotators or self.ANNOTATORS
        self.coerces_to = coerces_to or self.COERCES_TO

    def annotate(self, expression):
        """Annotate `expression` and the nodes below it with types, and return it.

        For traversable roots, each scope from `traverse_scope` is processed as follows.
        First, the scope's table-qualified column references are typed: from the
        schema for table sources, or by copying the `.type` of the matching
        projection of a derived source. Then the rest of the scope is annotated.
        Unqualified columns are skipped.
        NOTE(review): copying a projection's `.type` presumably relies on inner scopes
        being yielded before outer ones; confirm against `traverse_scope`.
        """
        if isinstance(expression, self.TRAVERSABLES):
            for scope in traverse_scope(expression):
                # Maps each derived source name to {output column name: expression}.
                selects = {}
                for name, source in scope.sources.items():
                    if not isinstance(source, Scope):
                        continue
                    if isinstance(source.expression, exp.UDTF):
                        values = []

                        if isinstance(source.expression, exp.Lateral):
                            # LATERAL EXPLODE(x): x supplies the column value.
                            if isinstance(source.expression.this, exp.Explode):
                                values = [source.expression.this.this]
                        elif source.expression.expressions:
                            # Other UDTFs take their values from their first argument. The
                            # guard avoids an IndexError on argument-less calls.
                            values = source.expression.expressions[0].expressions

                        if not values:
                            continue

                        selects[name] = dict(zip(source.expression.alias_column_names, values))
                    else:
                        selects[name] = {
                            select.alias_or_name: select for select in source.expression.selects
                        }
                # First annotate the current scope's column references
                for col in scope.columns:
                    if not col.table:
                        continue

                    source = scope.sources.get(col.table)
                    if isinstance(source, exp.Table):
                        col.type = self.schema.get_column_type(source, col)
                    elif source and col.table in selects and col.name in selects[col.table]:
                        col.type = selects[col.table][col.name].type
                # Then (possibly) annotate the remaining expressions in the scope
                self._maybe_annotate(scope.expression)
        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions

    def _maybe_annotate(self, expression):
        """Annotate `expression` unless it already has a type, and return it.

        Nodes without a registered annotator are typed UNKNOWN. Their children are
        still annotated.
        """
        if expression.type:
            return expression  # We've already inferred the expression's type

        annotator = self.annotators.get(expression.__class__)

        return (
            annotator(self, expression)
            if annotator
            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
        )

    def _annotate_args(self, expression):
        """Annotate every expression yielded by `expression.iter_expressions()`; return `expression`."""
        for _, value in expression.iter_expressions():
            self._maybe_annotate(value)

        return expression

    def _maybe_coerce(self, type1, type2):
        """Combine two types (DataType or DataType.Type) and return a DataType.Type.

        NULL takes precedence over everything, then UNKNOWN. Otherwise `type2` is
        returned if `type1` can be coerced into it, else `type1`.
        """
        # We propagate the NULL / UNKNOWN types upwards if found
        if isinstance(type1, exp.DataType):
            type1 = type1.this
        if isinstance(type2, exp.DataType):
            type2 = type2.this

        if exp.DataType.Type.NULL in (type1, type2):
            return exp.DataType.Type.NULL
        if exp.DataType.Type.UNKNOWN in (type1, type2):
            return exp.DataType.Type.UNKNOWN

        return type2 if type2 in self.coerces_to.get(type1, set()) else type1

    def _annotate_binary(self, expression):
        """Type a binary node from its (annotated) operands.

        AND/OR: NULL if both sides are NULL, NULLABLE(BOOLEAN) if exactly one is,
        otherwise BOOLEAN. Other conditions/predicates: BOOLEAN. Anything else: the
        coerced type of the two operands.
        """
        self._annotate_args(expression)

        left_type = expression.left.type.this
        right_type = expression.right.type.this

        if isinstance(expression, (exp.And, exp.Or)):
            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
                expression.type = exp.DataType.Type.NULL
            elif exp.DataType.Type.NULL in (left_type, right_type):
                expression.type = exp.DataType.build(
                    "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
                )
            else:
                expression.type = exp.DataType.Type.BOOLEAN
        elif isinstance(expression, (exp.Condition, exp.Predicate)):
            expression.type = exp.DataType.Type.BOOLEAN
        else:
            expression.type = self._maybe_coerce(left_type, right_type)

        return expression

    def _annotate_unary(self, expression):
        """Conditions (except parentheses) are BOOLEAN; otherwise inherit the operand's type."""
        self._annotate_args(expression)

        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
            expression.type = exp.DataType.Type.BOOLEAN
        else:
            expression.type = expression.this.type

        return expression

    def _annotate_literal(self, expression):
        """String literals are VARCHAR, integer literals INT, and any other literal DOUBLE."""
        if expression.is_string:
            expression.type = exp.DataType.Type.VARCHAR
        elif expression.is_int:
            expression.type = exp.DataType.Type.INT
        else:
            expression.type = exp.DataType.Type.DOUBLE

        return expression

    def _annotate_with_type(self, expression, target_type):
        """Set `expression.type` to `target_type`, then annotate its children."""
        expression.type = target_type
        return self._annotate_args(expression)

    def _annotate_by_args(self, expression, *args, promote=False):
        """Type `expression` by coercing together the types of the named args.

        Each named arg may hold a single expression or a list of them; missing or
        empty values are skipped. If nothing is found, the type is UNKNOWN. With
        `promote=True`, integer results become BIGINT and float results become
        DOUBLE.
        """
        self._annotate_args(expression)
        expressions = []
        for arg in args:
            arg_expr = expression.args.get(arg)
            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)

        last_datatype = None
        for expr in expressions:
            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)

        expression.type = last_datatype or exp.DataType.Type.UNKNOWN

        if promote:
            if expression.type.this in exp.DataType.INTEGER_TYPES:
                expression.type = exp.DataType.Type.BIGINT
            elif expression.type.this in exp.DataType.FLOAT_TYPES:
                expression.type = exp.DataType.Type.DOUBLE

        return expression
def
annotate(self, expression):
260 def annotate(self, expression): 261 if isinstance(expression, self.TRAVERSABLES): 262 for scope in traverse_scope(expression): 263 selects = {} 264 for name, source in scope.sources.items(): 265 if not isinstance(source, Scope): 266 continue 267 if isinstance(source.expression, exp.UDTF): 268 values = [] 269 270 if isinstance(source.expression, exp.Lateral): 271 if isinstance(source.expression.this, exp.Explode): 272 values = [source.expression.this.this] 273 else: 274 values = source.expression.expressions[0].expressions 275 276 if not values: 277 continue 278 279 selects[name] = { 280 alias: column 281 for alias, column in zip( 282 source.expression.alias_column_names, 283 values, 284 ) 285 } 286 else: 287 selects[name] = { 288 select.alias_or_name: select for select in source.expression.selects 289 } 290 # First annotate the current scope's column references 291 for col in scope.columns: 292 if not col.table: 293 continue 294 295 source = scope.sources.get(col.table) 296 if isinstance(source, exp.Table): 297 col.type = self.schema.get_column_type(source, col) 298 elif source and col.table in selects and col.name in selects[col.table]: 299 col.type = selects[col.table][col.name].type 300 # Then (possibly) annotate the remaining expressions in the scope 301 self._maybe_annotate(scope.expression) 302 return self._maybe_annotate(expression) # This takes care of non-traversable expressions