sqlglot.optimizer.annotate_types
from sqlglot import exp
from sqlglot.helper import ensure_collection, ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema


def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
    """
    Recursively infer & annotate types in an expression syntax tree against a schema.
    Assumes that we've already executed the optimizer's qualify_columns step.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression (sqlglot.Expression): Expression to annotate.
        schema (dict|sqlglot.optimizer.Schema): Database schema.
        annotators (dict): Maps expression type to corresponding annotation function.
        coerces_to (dict): Maps a data type to the set of data types it can be coerced into.
    Returns:
        sqlglot.Expression: expression annotated with types
    """

    schema = ensure_schema(schema)

    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)


def _fixed_type_annotator(target_type):
    """Build an annotator that always assigns ``target_type`` to the expression."""
    return lambda self, expr: self._annotate_with_type(expr, target_type)


def _coercions_from_precedence(*chains):
    """
    Expand precedence chains into a coercion table.

    Args:
        chains: Sequences of data types ordered from narrowest to widest.
    Returns:
        dict: Maps every type to the set of types that follow it in its chain.
    """
    return {
        narrow: set(chain[position + 1 :])
        for chain in chains
        for position, narrow in enumerate(chain)
    }


class TypeAnnotator:
    """
    Walks an expression tree and assigns a type to every expression in it.

    ``ANNOTATORS`` maps an expression class to a callable ``(annotator, expression)``
    that sets ``expression.type``; ``COERCES_TO`` maps a data type to the set of
    data types it may be widened to when two operand types are combined.
    """

    # Expressions whose result type does not depend on their arguments.
    _FIXED_TYPE_EXPRESSIONS = {
        exp.DataType.Type.BOOLEAN: (exp.Between, exp.In, exp.Boolean, exp.RegexpLike),
        exp.DataType.Type.NULL: (exp.Null,),
        exp.DataType.Type.UNKNOWN: (exp.Anonymous,),
        exp.DataType.Type.BIGINT: (exp.ApproxDistinct, exp.Count, exp.Length),
        exp.DataType.Type.DOUBLE: (
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Log2,
            exp.Log10,
            exp.Pow,
            exp.Quantile,
            exp.ApproxQuantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        ),
        exp.DataType.Type.INT: (
            exp.Ceil,
            exp.DateDiff,
            exp.DatetimeDiff,
            exp.Extract,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Floor,
            exp.Levenshtein,
            exp.StrPosition,
            exp.TsOrDiToDi,
        ),
        exp.DataType.Type.TINYINT: (exp.Day, exp.Month, exp.Week, exp.Year),
        exp.DataType.Type.DATE: (
            exp.CurrentDate,
            exp.DateAdd,
            exp.DateSub,
            exp.DateStrToDate,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        ),
        exp.DataType.Type.DATETIME: (exp.CurrentDatetime, exp.DatetimeAdd, exp.DatetimeSub),
        exp.DataType.Type.TIMESTAMP: (
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.TimeAdd,
            exp.TimeSub,
            exp.StrToTime,
            exp.TimeStrToTime,
            exp.UnixToTime,
        ),
        exp.DataType.Type.VARCHAR: (
            exp.DateToDateStr,
            exp.ConcatWs,
            exp.GroupConcat,
            exp.ArrayConcat,
            exp.Initcap,
            exp.Lower,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        ),
    }

    # Later entries win, so the fixed-type and explicit annotators below override
    # the generic unary / binary handling (e.g. for exp.Pow).
    ANNOTATORS = {
        **{
            expr_type: lambda self, expr: self._annotate_unary(expr)
            for expr_type in subclasses(exp.__name__, exp.Unary)
        },
        **{
            expr_type: lambda self, expr: self._annotate_binary(expr)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _fixed_type_annotator(data_type)
            for data_type, expr_types in _FIXED_TYPE_EXPRESSIONS.items()
            for expr_type in expr_types
        },
        exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
        exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
        exp.Alias: lambda self, expr: self._annotate_unary(expr),
        exp.Literal: lambda self, expr: self._annotate_literal(expr),
        exp.Min: lambda self, expr: self._annotate_by_args(expr, "this"),
        exp.Max: lambda self, expr: self._annotate_by_args(expr, "this"),
        exp.Sum: lambda self, expr: self._annotate_by_args(expr, "this", promote=True),
        exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
        exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
        exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
        exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
    }

    # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
    COERCES_TO = _coercions_from_precedence(
        # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
        (
            exp.DataType.Type.CHAR,
            exp.DataType.Type.NCHAR,
            exp.DataType.Type.VARCHAR,
            exp.DataType.Type.NVARCHAR,
            exp.DataType.Type.TEXT,
        ),
        # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
        (
            exp.DataType.Type.TINYINT,
            exp.DataType.Type.SMALLINT,
            exp.DataType.Type.INT,
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        ),
        # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
        (
            exp.DataType.Type.DATE,
            exp.DataType.Type.DATETIME,
            exp.DataType.Type.TIMESTAMP,
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMPLTZ,
        ),
    )

    TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)

    def __init__(self, schema=None, annotators=None, coerces_to=None):
        """
        Args:
            schema: Schema object queried for the types of table columns.
            annotators (dict): Overrides ``ANNOTATORS`` when truthy.
            coerces_to (dict): Overrides ``COERCES_TO`` when truthy.
        """
        self.schema = schema
        self.annotators = annotators or self.ANNOTATORS
        self.coerces_to = coerces_to or self.COERCES_TO

    def annotate(self, expression):
        """
        Annotate ``expression`` and the expressions nested in it with types.

        For traversable expressions (see ``TRAVERSABLES``) every scope yielded by
        ``traverse_scope`` is handled in turn: its table-qualified columns are typed
        first, then the scope's own expression is annotated.

        Args:
            expression: Expression to annotate.
        Returns:
            The value returned by ``_maybe_annotate(expression)``.
        """
        if isinstance(expression, self.TRAVERSABLES):
            for scope in traverse_scope(expression):
                # Maps source name -> {output column name -> expression producing it}
                selects = {}
                for name, source in scope.sources.items():
                    if not isinstance(source, Scope):
                        continue
                    if isinstance(source.expression, exp.UDTF):
                        values = []

                        if isinstance(source.expression, exp.Lateral):
                            if isinstance(source.expression.this, exp.Explode):
                                values = [source.expression.this.this]
                        else:
                            values = source.expression.expressions[0].expressions

                        if not values:
                            continue

                        selects[name] = {
                            alias: column
                            for alias, column in zip(
                                source.expression.alias_column_names,
                                values,
                            )
                        }
                    else:
                        selects[name] = {
                            select.alias_or_name: select for select in source.expression.selects
                        }
                # First annotate the current scope's column references
                for col in scope.columns:
                    if not col.table:
                        continue

                    source = scope.sources.get(col.table)
                    if isinstance(source, exp.Table):
                        col.type = self.schema.get_column_type(source, col)
                    elif source and col.table in selects:
                        # The source may not expose a projection under this name
                        # (e.g. fewer UDTF column aliases than values); leave such a
                        # column untyped instead of raising KeyError, so that it is
                        # annotated as UNKNOWN below.
                        select = selects[col.table].get(col.name)
                        if select is not None:
                            col.type = select.type
                # Then (possibly) annotate the remaining expressions in the scope
                self._maybe_annotate(scope.expression)
        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions

    def _maybe_annotate(self, expression):
        """
        Annotate ``expression`` unless it already carries a type.

        Returns:
            None when ``expression`` is not an ``exp.Expression``, otherwise the
            (possibly newly annotated) expression.
        """
        if not isinstance(expression, exp.Expression):
            return None

        if expression.type:
            return expression  # We've already inferred the expression's type

        annotator = self.annotators.get(expression.__class__)

        return (
            annotator(self, expression)
            if annotator
            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
        )

    def _annotate_args(self, expression):
        """Annotate every argument of ``expression`` and return ``expression``."""
        for value in expression.args.values():
            for v in ensure_collection(value):
                self._maybe_annotate(v)

        return expression

    def _maybe_coerce(self, type1, type2):
        """
        Combine two operand types into a single result type.

        Returns ``type2`` when ``type1`` coerces to it according to ``coerces_to``,
        otherwise ``type1``. ``exp.DataType`` inputs are reduced to their ``this``.
        """
        # We propagate the NULL / UNKNOWN types upwards if found
        if isinstance(type1, exp.DataType):
            type1 = type1.this
        if isinstance(type2, exp.DataType):
            type2 = type2.this

        if exp.DataType.Type.NULL in (type1, type2):
            return exp.DataType.Type.NULL
        if exp.DataType.Type.UNKNOWN in (type1, type2):
            return exp.DataType.Type.UNKNOWN

        return type2 if type2 in self.coerces_to.get(type1, {}) else type1

    def _annotate_binary(self, expression):
        """
        Annotate a binary expression from the types of its two operands.

        AND / OR are NULL when both sides are NULL, NULLABLE(BOOLEAN) when exactly
        one side is NULL and BOOLEAN otherwise. Other conditions and predicates are
        BOOLEAN; everything else gets the coerced operand type.
        """
        self._annotate_args(expression)

        left_type = expression.left.type.this
        right_type = expression.right.type.this

        if isinstance(expression, (exp.And, exp.Or)):
            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
                expression.type = exp.DataType.Type.NULL
            elif exp.DataType.Type.NULL in (left_type, right_type):
                expression.type = exp.DataType.build(
                    "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
                )
            else:
                expression.type = exp.DataType.Type.BOOLEAN
        elif isinstance(expression, (exp.Condition, exp.Predicate)):
            expression.type = exp.DataType.Type.BOOLEAN
        else:
            expression.type = self._maybe_coerce(left_type, right_type)

        return expression

    def _annotate_unary(self, expression):
        """
        Annotate a unary expression: BOOLEAN for conditions other than ``exp.Paren``,
        otherwise the type of the wrapped operand.
        """
        self._annotate_args(expression)

        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
            expression.type = exp.DataType.Type.BOOLEAN
        else:
            expression.type = expression.this.type

        return expression

    def _annotate_literal(self, expression):
        """Annotate a literal as VARCHAR (string), INT (integer) or DOUBLE (anything else)."""
        if expression.is_string:
            expression.type = exp.DataType.Type.VARCHAR
        elif expression.is_int:
            expression.type = exp.DataType.Type.INT
        else:
            expression.type = exp.DataType.Type.DOUBLE

        return expression

    def _annotate_with_type(self, expression, target_type):
        """Assign ``target_type`` to ``expression``, then annotate its arguments."""
        expression.type = target_type
        return self._annotate_args(expression)

    def _annotate_by_args(self, expression, *args, promote=False):
        """
        Derive the type of ``expression`` from the types of selected arguments.

        Args:
            expression: Expression to annotate.
            *args: Names of the arguments whose types are coerced together.
            promote: When True, widen integer results to BIGINT and floating point
                results to DOUBLE.
        Returns:
            The annotated expression; its type is UNKNOWN when none of the named
            arguments is present.
        """
        self._annotate_args(expression)
        expressions = []
        for arg in args:
            arg_expr = expression.args.get(arg)
            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)

        last_datatype = None
        for expr in expressions:
            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)

        expression.type = last_datatype or exp.DataType.Type.UNKNOWN

        if promote:
            if expression.type.this in exp.DataType.INTEGER_TYPES:
                expression.type = exp.DataType.Type.BIGINT
            elif expression.type.this in exp.DataType.FLOAT_TYPES:
                expression.type = exp.DataType.Type.DOUBLE

        return expression
def
annotate_types(expression, schema=None, annotators=None, coerces_to=None):
def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
    """
    Infer the type of every node in an expression tree and record it on the node.

    The schema is normalised with ``ensure_schema`` and the work is delegated to
    ``TypeAnnotator.annotate``. Column references are expected to be qualified
    already, i.e. the optimizer's qualify_columns step should have been run.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression (sqlglot.Expression): Expression to annotate.
        schema (dict|sqlglot.optimizer.Schema): Database schema.
        annotators (dict): Maps expression type to corresponding annotation function.
        coerces_to (dict): Maps a data type to the set of data types it can be coerced into.
    Returns:
        sqlglot.Expression: expression annotated with types
    """
    annotator = TypeAnnotator(ensure_schema(schema), annotators, coerces_to)
    return annotator.annotate(expression)
Recursively infer & annotate types in an expression syntax tree against a schema. Assumes that we've already executed the optimizer's qualify_columns step.
Example:
>>> import sqlglot
>>> schema = {"y": {"cola": "SMALLINT"}}
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
>>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
<Type.DOUBLE: 'DOUBLE'>
Arguments:
- expression (sqlglot.Expression): Expression to annotate.
- schema (dict|sqlglot.optimizer.Schema): Database schema.
- annotators (dict): Maps expression type to corresponding annotation function.
- coerces_to (dict): Maps a data type to the set of data types that it can be coerced into.
Returns:
sqlglot.Expression: expression annotated with types
class
TypeAnnotator:
def _fixed_type_annotator(target_type):
    """Build an annotator that always assigns ``target_type`` to the expression."""
    return lambda self, expr: self._annotate_with_type(expr, target_type)


def _coercions_from_precedence(*chains):
    """
    Expand precedence chains into a coercion table.

    Args:
        chains: Sequences of data types ordered from narrowest to widest.
    Returns:
        dict: Maps every type to the set of types that follow it in its chain.
    """
    return {
        narrow: set(chain[position + 1 :])
        for chain in chains
        for position, narrow in enumerate(chain)
    }


class TypeAnnotator:
    """
    Walks an expression tree and assigns a type to every expression in it.

    ``ANNOTATORS`` maps an expression class to a callable ``(annotator, expression)``
    that sets ``expression.type``; ``COERCES_TO`` maps a data type to the set of
    data types it may be widened to when two operand types are combined.
    """

    # Expressions whose result type does not depend on their arguments.
    _FIXED_TYPE_EXPRESSIONS = {
        exp.DataType.Type.BOOLEAN: (exp.Between, exp.In, exp.Boolean, exp.RegexpLike),
        exp.DataType.Type.NULL: (exp.Null,),
        exp.DataType.Type.UNKNOWN: (exp.Anonymous,),
        exp.DataType.Type.BIGINT: (exp.ApproxDistinct, exp.Count, exp.Length),
        exp.DataType.Type.DOUBLE: (
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Log2,
            exp.Log10,
            exp.Pow,
            exp.Quantile,
            exp.ApproxQuantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        ),
        exp.DataType.Type.INT: (
            exp.Ceil,
            exp.DateDiff,
            exp.DatetimeDiff,
            exp.Extract,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Floor,
            exp.Levenshtein,
            exp.StrPosition,
            exp.TsOrDiToDi,
        ),
        exp.DataType.Type.TINYINT: (exp.Day, exp.Month, exp.Week, exp.Year),
        exp.DataType.Type.DATE: (
            exp.CurrentDate,
            exp.DateAdd,
            exp.DateSub,
            exp.DateStrToDate,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        ),
        exp.DataType.Type.DATETIME: (exp.CurrentDatetime, exp.DatetimeAdd, exp.DatetimeSub),
        exp.DataType.Type.TIMESTAMP: (
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.TimeAdd,
            exp.TimeSub,
            exp.StrToTime,
            exp.TimeStrToTime,
            exp.UnixToTime,
        ),
        exp.DataType.Type.VARCHAR: (
            exp.DateToDateStr,
            exp.ConcatWs,
            exp.GroupConcat,
            exp.ArrayConcat,
            exp.Initcap,
            exp.Lower,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        ),
    }

    # Later entries win, so the fixed-type and explicit annotators below override
    # the generic unary / binary handling (e.g. for exp.Pow).
    ANNOTATORS = {
        **{
            expr_type: lambda self, expr: self._annotate_unary(expr)
            for expr_type in subclasses(exp.__name__, exp.Unary)
        },
        **{
            expr_type: lambda self, expr: self._annotate_binary(expr)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _fixed_type_annotator(data_type)
            for data_type, expr_types in _FIXED_TYPE_EXPRESSIONS.items()
            for expr_type in expr_types
        },
        exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
        exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
        exp.Alias: lambda self, expr: self._annotate_unary(expr),
        exp.Literal: lambda self, expr: self._annotate_literal(expr),
        exp.Min: lambda self, expr: self._annotate_by_args(expr, "this"),
        exp.Max: lambda self, expr: self._annotate_by_args(expr, "this"),
        exp.Sum: lambda self, expr: self._annotate_by_args(expr, "this", promote=True),
        exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
        exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
        exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
        exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
    }

    # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
    COERCES_TO = _coercions_from_precedence(
        # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
        (
            exp.DataType.Type.CHAR,
            exp.DataType.Type.NCHAR,
            exp.DataType.Type.VARCHAR,
            exp.DataType.Type.NVARCHAR,
            exp.DataType.Type.TEXT,
        ),
        # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
        (
            exp.DataType.Type.TINYINT,
            exp.DataType.Type.SMALLINT,
            exp.DataType.Type.INT,
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DOUBLE,
        ),
        # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
        (
            exp.DataType.Type.DATE,
            exp.DataType.Type.DATETIME,
            exp.DataType.Type.TIMESTAMP,
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMPLTZ,
        ),
    )

    TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)

    def __init__(self, schema=None, annotators=None, coerces_to=None):
        """
        Args:
            schema: Schema object queried for the types of table columns.
            annotators (dict): Overrides ``ANNOTATORS`` when truthy.
            coerces_to (dict): Overrides ``COERCES_TO`` when truthy.
        """
        self.schema = schema
        self.annotators = annotators or self.ANNOTATORS
        self.coerces_to = coerces_to or self.COERCES_TO

    def annotate(self, expression):
        """
        Annotate ``expression`` and the expressions nested in it with types.

        For traversable expressions (see ``TRAVERSABLES``) every scope yielded by
        ``traverse_scope`` is handled in turn: its table-qualified columns are typed
        first, then the scope's own expression is annotated.

        Args:
            expression: Expression to annotate.
        Returns:
            The value returned by ``_maybe_annotate(expression)``.
        """
        if isinstance(expression, self.TRAVERSABLES):
            for scope in traverse_scope(expression):
                # Maps source name -> {output column name -> expression producing it}
                selects = {}
                for name, source in scope.sources.items():
                    if not isinstance(source, Scope):
                        continue
                    if isinstance(source.expression, exp.UDTF):
                        values = []

                        if isinstance(source.expression, exp.Lateral):
                            if isinstance(source.expression.this, exp.Explode):
                                values = [source.expression.this.this]
                        else:
                            values = source.expression.expressions[0].expressions

                        if not values:
                            continue

                        selects[name] = {
                            alias: column
                            for alias, column in zip(
                                source.expression.alias_column_names,
                                values,
                            )
                        }
                    else:
                        selects[name] = {
                            select.alias_or_name: select for select in source.expression.selects
                        }
                # First annotate the current scope's column references
                for col in scope.columns:
                    if not col.table:
                        continue

                    source = scope.sources.get(col.table)
                    if isinstance(source, exp.Table):
                        col.type = self.schema.get_column_type(source, col)
                    elif source and col.table in selects:
                        # The source may not expose a projection under this name
                        # (e.g. fewer UDTF column aliases than values); leave such a
                        # column untyped instead of raising KeyError, so that it is
                        # annotated as UNKNOWN below.
                        select = selects[col.table].get(col.name)
                        if select is not None:
                            col.type = select.type
                # Then (possibly) annotate the remaining expressions in the scope
                self._maybe_annotate(scope.expression)
        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions

    def _maybe_annotate(self, expression):
        """
        Annotate ``expression`` unless it already carries a type.

        Returns:
            None when ``expression`` is not an ``exp.Expression``, otherwise the
            (possibly newly annotated) expression.
        """
        if not isinstance(expression, exp.Expression):
            return None

        if expression.type:
            return expression  # We've already inferred the expression's type

        annotator = self.annotators.get(expression.__class__)

        return (
            annotator(self, expression)
            if annotator
            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
        )

    def _annotate_args(self, expression):
        """Annotate every argument of ``expression`` and return ``expression``."""
        for value in expression.args.values():
            for v in ensure_collection(value):
                self._maybe_annotate(v)

        return expression

    def _maybe_coerce(self, type1, type2):
        """
        Combine two operand types into a single result type.

        Returns ``type2`` when ``type1`` coerces to it according to ``coerces_to``,
        otherwise ``type1``. ``exp.DataType`` inputs are reduced to their ``this``.
        """
        # We propagate the NULL / UNKNOWN types upwards if found
        if isinstance(type1, exp.DataType):
            type1 = type1.this
        if isinstance(type2, exp.DataType):
            type2 = type2.this

        if exp.DataType.Type.NULL in (type1, type2):
            return exp.DataType.Type.NULL
        if exp.DataType.Type.UNKNOWN in (type1, type2):
            return exp.DataType.Type.UNKNOWN

        return type2 if type2 in self.coerces_to.get(type1, {}) else type1

    def _annotate_binary(self, expression):
        """
        Annotate a binary expression from the types of its two operands.

        AND / OR are NULL when both sides are NULL, NULLABLE(BOOLEAN) when exactly
        one side is NULL and BOOLEAN otherwise. Other conditions and predicates are
        BOOLEAN; everything else gets the coerced operand type.
        """
        self._annotate_args(expression)

        left_type = expression.left.type.this
        right_type = expression.right.type.this

        if isinstance(expression, (exp.And, exp.Or)):
            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
                expression.type = exp.DataType.Type.NULL
            elif exp.DataType.Type.NULL in (left_type, right_type):
                expression.type = exp.DataType.build(
                    "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
                )
            else:
                expression.type = exp.DataType.Type.BOOLEAN
        elif isinstance(expression, (exp.Condition, exp.Predicate)):
            expression.type = exp.DataType.Type.BOOLEAN
        else:
            expression.type = self._maybe_coerce(left_type, right_type)

        return expression

    def _annotate_unary(self, expression):
        """
        Annotate a unary expression: BOOLEAN for conditions other than ``exp.Paren``,
        otherwise the type of the wrapped operand.
        """
        self._annotate_args(expression)

        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
            expression.type = exp.DataType.Type.BOOLEAN
        else:
            expression.type = expression.this.type

        return expression

    def _annotate_literal(self, expression):
        """Annotate a literal as VARCHAR (string), INT (integer) or DOUBLE (anything else)."""
        if expression.is_string:
            expression.type = exp.DataType.Type.VARCHAR
        elif expression.is_int:
            expression.type = exp.DataType.Type.INT
        else:
            expression.type = exp.DataType.Type.DOUBLE

        return expression

    def _annotate_with_type(self, expression, target_type):
        """Assign ``target_type`` to ``expression``, then annotate its arguments."""
        expression.type = target_type
        return self._annotate_args(expression)

    def _annotate_by_args(self, expression, *args, promote=False):
        """
        Derive the type of ``expression`` from the types of selected arguments.

        Args:
            expression: Expression to annotate.
            *args: Names of the arguments whose types are coerced together.
            promote: When True, widen integer results to BIGINT and floating point
                results to DOUBLE.
        Returns:
            The annotated expression; its type is UNKNOWN when none of the named
            arguments is present.
        """
        self._annotate_args(expression)
        expressions = []
        for arg in args:
            arg_expr = expression.args.get(arg)
            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)

        last_datatype = None
        for expr in expressions:
            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)

        expression.type = last_datatype or exp.DataType.Type.UNKNOWN

        if promote:
            if expression.type.this in exp.DataType.INTEGER_TYPES:
                expression.type = exp.DataType.Type.BIGINT
            elif expression.type.this in exp.DataType.FLOAT_TYPES:
                expression.type = exp.DataType.Type.DOUBLE

        return expression
def
annotate(self, expression):
def annotate(self, expression):
    """
    Annotate `expression` (and everything beneath it) with types.

    When `expression` is one of `TRAVERSABLES`, every scope yielded by `traverse_scope`
    is processed in two steps: table-qualified column references are typed first (from
    the schema for physical tables, from the projections of the source for derived
    tables / UDTFs), then the scope's own expression is annotated. Finally the root
    expression itself is annotated, which also covers non-traversable inputs.

    Args:
        expression: Expression to annotate.
    Returns:
        Whatever `_maybe_annotate` returns for `expression`.
    """
    if isinstance(expression, self.TRAVERSABLES):
        for scope in traverse_scope(expression):
            # source name -> {output column name -> expression producing that column}
            selects = {}
            for name, source in scope.sources.items():
                if not isinstance(source, Scope):
                    continue

                source_expr = source.expression
                if isinstance(source_expr, exp.UDTF):
                    values = []

                    if isinstance(source_expr, exp.Lateral):
                        # Only LATERAL ... EXPLODE(x) is handled: its single output is `x`
                        if isinstance(source_expr.this, exp.Explode):
                            values = [source_expr.this.this]
                    else:
                        values = source_expr.expressions[0].expressions

                    if not values:
                        continue

                    # zip stops at the shorter sequence, so surplus aliases or values
                    # simply get no entry here
                    selects[name] = dict(zip(source_expr.alias_column_names, values))
                else:
                    selects[name] = {
                        select.alias_or_name: select for select in source_expr.selects
                    }

            # First annotate the current scope's column references
            for col in scope.columns:
                if not col.table:
                    continue

                source = scope.sources.get(col.table)
                if isinstance(source, exp.Table):
                    col.type = self.schema.get_column_type(source, col)
                elif source and col.table in selects:
                    # The source may not expose a projection under this name (e.g. a
                    # star projection or fewer UDTF aliases than values). Leave such a
                    # column untyped instead of raising KeyError; it is then handled by
                    # the _maybe_annotate pass below.
                    select = selects[col.table].get(col.name)
                    if select is not None:
                        col.type = select.type

            # Then (possibly) annotate the remaining expressions in the scope
            self._maybe_annotate(scope.expression)

    return self._maybe_annotate(expression)  # This takes care of non-traversable expressions