sqlglot.optimizer.annotate_types
from sqlglot import exp
from sqlglot.helper import ensure_collection, ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema


def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
    """
    Infer the type of every node in an expression tree and record it on the node.

    The optimizer's qualify_columns step is assumed to have been run beforehand.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression (sqlglot.Expression): Expression to annotate.
        schema (dict|sqlglot.optimizer.Schema): Database schema; passed through `ensure_schema`.
        annotators (dict): Maps expression type to corresponding annotation function.
        coerces_to (dict): Maps a data type to the set of data types it can be coerced into.
    Returns:
        sqlglot.Expression: expression annotated with types
    """
    annotator = TypeAnnotator(ensure_schema(schema), annotators, coerces_to)
    return annotator.annotate(expression)


_Type = exp.DataType.Type


def _fixed_type(data_type):
    """Return an annotator that tags an expression with the constant `data_type`."""
    return lambda self, expr: self._annotate_with_type(expr, data_type)


def _widening_chain(*narrowest_to_widest):
    """
    Expand an ordered chain of types into a coercion table.

    Every type maps to the set of all types that follow it in the chain, so the last
    (widest) type maps to an empty set. Keys are inserted widest-first.
    """
    table = {}
    for position in reversed(range(len(narrowest_to_widest))):
        table[narrowest_to_widest[position]] = set(narrowest_to_widest[position + 1 :])
    return table


class TypeAnnotator:
    """
    Infers a type for the nodes of an expression tree and stores it on each node.

    Attributes:
        ANNOTATORS: Default mapping from expression class to a function called as
            `annotator(type_annotator, expression)` that annotates that kind of node.
        COERCES_TO: Default mapping from a data type to the set of data types it can be
            coerced into.
        TRAVERSABLES: Expression classes that `annotate` processes scope by scope.
    """

    ANNOTATORS = {
        # Every Unary / Binary subclass gets the generic handler; the explicit entries
        # that follow take precedence for classes that appear in both places.
        **dict.fromkeys(
            subclasses(exp.__name__, exp.Unary),
            lambda self, expr: self._annotate_unary(expr),
        ),
        **dict.fromkeys(
            subclasses(exp.__name__, exp.Binary),
            lambda self, expr: self._annotate_binary(expr),
        ),
        exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
        exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
        exp.Alias: lambda self, expr: self._annotate_unary(expr),
        exp.Between: _fixed_type(_Type.BOOLEAN),
        exp.In: _fixed_type(_Type.BOOLEAN),
        exp.Literal: lambda self, expr: self._annotate_literal(expr),
        exp.Boolean: _fixed_type(_Type.BOOLEAN),
        exp.Null: _fixed_type(_Type.NULL),
        exp.Anonymous: _fixed_type(_Type.UNKNOWN),
        exp.ApproxDistinct: _fixed_type(_Type.BIGINT),
        exp.Avg: _fixed_type(_Type.DOUBLE),
        exp.Min: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
        exp.Max: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
        exp.Sum: lambda self, expr: self._annotate_by_args(
            expr, "this", "expressions", promote=True
        ),
        exp.Ceil: _fixed_type(_Type.INT),
        exp.Count: _fixed_type(_Type.BIGINT),
        exp.CurrentDate: _fixed_type(_Type.DATE),
        exp.CurrentDatetime: _fixed_type(_Type.DATETIME),
        exp.CurrentTime: _fixed_type(_Type.TIMESTAMP),
        exp.CurrentTimestamp: _fixed_type(_Type.TIMESTAMP),
        exp.DateAdd: _fixed_type(_Type.DATE),
        exp.DateSub: _fixed_type(_Type.DATE),
        exp.DateDiff: _fixed_type(_Type.INT),
        exp.DatetimeAdd: _fixed_type(_Type.DATETIME),
        exp.DatetimeSub: _fixed_type(_Type.DATETIME),
        exp.DatetimeDiff: _fixed_type(_Type.INT),
        exp.Extract: _fixed_type(_Type.INT),
        exp.TimestampAdd: _fixed_type(_Type.TIMESTAMP),
        exp.TimestampSub: _fixed_type(_Type.TIMESTAMP),
        exp.TimestampDiff: _fixed_type(_Type.INT),
        exp.TimeAdd: _fixed_type(_Type.TIMESTAMP),
        exp.TimeSub: _fixed_type(_Type.TIMESTAMP),
        exp.TimeDiff: _fixed_type(_Type.INT),
        exp.DateStrToDate: _fixed_type(_Type.DATE),
        exp.DateToDateStr: _fixed_type(_Type.VARCHAR),
        exp.DateToDi: _fixed_type(_Type.INT),
        exp.Day: _fixed_type(_Type.TINYINT),
        exp.DiToDate: _fixed_type(_Type.DATE),
        exp.Exp: _fixed_type(_Type.DOUBLE),
        exp.Floor: _fixed_type(_Type.INT),
        exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
        exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
        exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
        exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
        exp.ConcatWs: _fixed_type(_Type.VARCHAR),
        exp.GroupConcat: _fixed_type(_Type.VARCHAR),
        exp.ArrayConcat: _fixed_type(_Type.VARCHAR),
        exp.Initcap: _fixed_type(_Type.VARCHAR),
        exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
        exp.Length: _fixed_type(_Type.BIGINT),
        exp.Levenshtein: _fixed_type(_Type.INT),
        exp.Ln: _fixed_type(_Type.DOUBLE),
        exp.Log: _fixed_type(_Type.DOUBLE),
        exp.Log2: _fixed_type(_Type.DOUBLE),
        exp.Log10: _fixed_type(_Type.DOUBLE),
        exp.Lower: _fixed_type(_Type.VARCHAR),
        exp.Month: _fixed_type(_Type.TINYINT),
        exp.Pow: _fixed_type(_Type.DOUBLE),
        exp.Quantile: _fixed_type(_Type.DOUBLE),
        exp.ApproxQuantile: _fixed_type(_Type.DOUBLE),
        exp.RegexpLike: _fixed_type(_Type.BOOLEAN),
        exp.Round: _fixed_type(_Type.DOUBLE),
        exp.SafeDivide: _fixed_type(_Type.DOUBLE),
        exp.Substring: _fixed_type(_Type.VARCHAR),
        exp.StrPosition: _fixed_type(_Type.INT),
        exp.StrToDate: _fixed_type(_Type.DATE),
        exp.StrToTime: _fixed_type(_Type.TIMESTAMP),
        exp.Sqrt: _fixed_type(_Type.DOUBLE),
        exp.Stddev: _fixed_type(_Type.DOUBLE),
        exp.StddevPop: _fixed_type(_Type.DOUBLE),
        exp.StddevSamp: _fixed_type(_Type.DOUBLE),
        exp.TimeToStr: _fixed_type(_Type.VARCHAR),
        exp.TimeToTimeStr: _fixed_type(_Type.VARCHAR),
        exp.TimeStrToDate: _fixed_type(_Type.DATE),
        exp.TimeStrToTime: _fixed_type(_Type.TIMESTAMP),
        exp.Trim: _fixed_type(_Type.VARCHAR),
        exp.TsOrDsToDateStr: _fixed_type(_Type.VARCHAR),
        exp.TsOrDsToDate: _fixed_type(_Type.DATE),
        exp.TsOrDiToDi: _fixed_type(_Type.INT),
        exp.UnixToStr: _fixed_type(_Type.VARCHAR),
        exp.UnixToTime: _fixed_type(_Type.TIMESTAMP),
        exp.UnixToTimeStr: _fixed_type(_Type.VARCHAR),
        exp.Upper: _fixed_type(_Type.VARCHAR),
        exp.Variance: _fixed_type(_Type.DOUBLE),
        exp.VariancePop: _fixed_type(_Type.DOUBLE),
        exp.Week: _fixed_type(_Type.TINYINT),
        exp.Year: _fixed_type(_Type.TINYINT),
    }

    # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
    COERCES_TO = {
        # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
        **_widening_chain(_Type.CHAR, _Type.NCHAR, _Type.VARCHAR, _Type.NVARCHAR, _Type.TEXT),
        # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
        **_widening_chain(
            _Type.TINYINT,
            _Type.SMALLINT,
            _Type.INT,
            _Type.BIGINT,
            _Type.DECIMAL,
            _Type.FLOAT,
            _Type.DOUBLE,
        ),
        # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
        **_widening_chain(
            _Type.DATE,
            _Type.DATETIME,
            _Type.TIMESTAMP,
            _Type.TIMESTAMPTZ,
            _Type.TIMESTAMPLTZ,
        ),
    }

    TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)

    def __init__(self, schema=None, annotators=None, coerces_to=None):
        """
        Store the schema and lookup tables used during annotation.

        A falsy `annotators` or `coerces_to` selects the class-level default table.
        """
        self.schema = schema
        self.annotators = annotators if annotators else self.ANNOTATORS
        self.coerces_to = coerces_to if coerces_to else self.COERCES_TO

    def annotate(self, expression):
        """
        Annotate `expression` and the nodes beneath it with inferred types.

        When `expression` is one of `TRAVERSABLES`, every scope produced by
        `traverse_scope` is handled in turn: columns that name a table are typed first
        (from the schema for table sources, otherwise from the matching projection of
        the source scope), then the scope's own expression is annotated.

        Returns:
            The result of `_maybe_annotate(expression)`: the expression itself, or None
            when it is not an `exp.Expression`.
        """
        if not isinstance(expression, self.TRAVERSABLES):
            return self._maybe_annotate(expression)

        for scope in traverse_scope(expression):
            # Per source name: the expressions that source exposes, keyed by output name
            projections = {}
            for source_name, source in scope.sources.items():
                if not isinstance(source, Scope):
                    continue

                node = source.expression
                if not isinstance(node, exp.UDTF):
                    projections[source_name] = {
                        projection.alias_or_name: projection for projection in node.selects
                    }
                    continue

                if not isinstance(node, exp.Lateral):
                    values = node.expressions[0].expressions
                elif isinstance(node.this, exp.Explode):
                    values = [node.this.this]
                else:
                    values = []

                if values:
                    projections[source_name] = dict(zip(node.alias_column_names, values))

            # Column references of the current scope are typed before anything else
            for column in scope.columns:
                table_name = column.table
                if not table_name:
                    continue

                source = scope.sources.get(table_name)
                if isinstance(source, exp.Table):
                    column.type = self.schema.get_column_type(source, column)
                elif source and column.name in projections.get(table_name, {}):
                    column.type = projections[table_name][column.name].type

            # Whatever is still untyped in the scope is handled by the annotators
            self._maybe_annotate(scope.expression)

        # Covers any part of the tree that the scope traversal did not annotate
        return self._maybe_annotate(expression)

    def _maybe_annotate(self, expression):
        """
        Annotate `expression` unless it already carries a type.

        Returns None for anything that is not an `exp.Expression`. Expression classes
        without a registered annotator are typed as UNKNOWN.
        """
        if not isinstance(expression, exp.Expression):
            return None

        if expression.type:
            return expression  # Already typed; nothing left to infer

        handler = self.annotators.get(expression.__class__)
        if handler:
            return handler(self, expression)
        return self._annotate_with_type(expression, _Type.UNKNOWN)

    def _annotate_args(self, expression):
        """Run `_maybe_annotate` on every value held in `expression.args`; return `expression`."""
        for arg_value in expression.args.values():
            for item in ensure_collection(arg_value):
                self._maybe_annotate(item)

        return expression

    def _maybe_coerce(self, type1, type2):
        """
        Pick the type that results from combining `type1` with `type2`.

        `exp.DataType` arguments are reduced to their `this` value first. NULL wins over
        everything, then UNKNOWN; otherwise `type2` is returned when `type1` can be
        coerced into it according to `self.coerces_to`, and `type1` is returned if not.
        """
        left = type1.this if isinstance(type1, exp.DataType) else type1
        right = type2.this if isinstance(type2, exp.DataType) else type2
        operands = (left, right)

        # NULL / UNKNOWN propagate upwards, NULL taking priority
        for sentinel in (_Type.NULL, _Type.UNKNOWN):
            if sentinel in operands:
                return sentinel

        if right in self.coerces_to.get(left, {}):
            return right
        return left

    def _annotate_binary(self, expression):
        """
        Annotate a binary expression after annotating its arguments.

        AND / OR yield NULL when both sides are NULL, NULLABLE<BOOLEAN> when exactly one
        side is NULL, and BOOLEAN otherwise. Other conditions and predicates are BOOLEAN;
        anything else gets the coerced type of its two operands.
        """
        self._annotate_args(expression)

        lhs = expression.left.type.this
        rhs = expression.right.type.this

        if not isinstance(expression, (exp.And, exp.Or)):
            if isinstance(expression, (exp.Condition, exp.Predicate)):
                inferred = _Type.BOOLEAN
            else:
                inferred = self._maybe_coerce(lhs, rhs)
        elif lhs == _Type.NULL and rhs == _Type.NULL:
            inferred = _Type.NULL
        elif _Type.NULL in (lhs, rhs):
            inferred = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN"))
        else:
            inferred = _Type.BOOLEAN

        expression.type = inferred
        return expression

    def _annotate_unary(self, expression):
        """
        Annotate a unary expression after annotating its arguments.

        Parentheses and non-condition expressions take the type of their operand; every
        other condition is BOOLEAN.
        """
        self._annotate_args(expression)

        if isinstance(expression, exp.Paren) or not isinstance(expression, exp.Condition):
            expression.type = expression.this.type
        else:
            expression.type = _Type.BOOLEAN

        return expression

    def _annotate_literal(self, expression):
        """Type a literal as VARCHAR (string), INT (integer) or DOUBLE (anything else)."""
        if expression.is_string:
            inferred = _Type.VARCHAR
        else:
            inferred = _Type.INT if expression.is_int else _Type.DOUBLE

        expression.type = inferred
        return expression

    def _annotate_with_type(self, expression, target_type):
        """Assign `target_type` to `expression`, then annotate its arguments."""
        expression.type = target_type
        return self._annotate_args(expression)

    def _annotate_by_args(self, expression, *args, promote=False):
        """
        Derive the type of `expression` from the arguments named in `args`.

        The named arguments are gathered in order (falsy entries skipped) and their types
        are folded together with `_maybe_coerce`. With no usable argument the type is
        UNKNOWN. When `promote` is set, integer results become BIGINT and floating point
        results become DOUBLE.
        """
        self._annotate_args(expression)

        operands = [
            operand
            for arg_name in args
            for operand in ensure_list(expression.args.get(arg_name))
            if operand
        ]

        combined = None
        for operand in operands:
            combined = self._maybe_coerce(combined or operand.type, operand.type)

        expression.type = combined or _Type.UNKNOWN

        if promote:
            current = expression.type.this
            if current in exp.DataType.INTEGER_TYPES:
                expression.type = _Type.BIGINT
            elif current in exp.DataType.FLOAT_TYPES:
                expression.type = _Type.DOUBLE

        return expression
def
annotate_types(expression, schema=None, annotators=None, coerces_to=None):
def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
    """
    Infer the type of every node in an expression tree and record it on the node.

    The optimizer's qualify_columns step is assumed to have been run beforehand.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression (sqlglot.Expression): Expression to annotate.
        schema (dict|sqlglot.optimizer.Schema): Database schema; passed through `ensure_schema`.
        annotators (dict): Maps expression type to corresponding annotation function.
        coerces_to (dict): Maps a data type to the set of data types it can be coerced into.
    Returns:
        sqlglot.Expression: expression annotated with types
    """
    annotator = TypeAnnotator(ensure_schema(schema), annotators, coerces_to)
    return annotator.annotate(expression)
Recursively infer & annotate types in an expression syntax tree against a schema. Assumes that we've already executed the optimizer's qualify_columns step.
Example:
>>> import sqlglot
>>> schema = {"y": {"cola": "SMALLINT"}}
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
>>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
<Type.DOUBLE: 'DOUBLE'>
Arguments:
- expression (sqlglot.Expression): Expression to annotate.
- schema (dict|sqlglot.optimizer.Schema): Database schema.
- annotators (dict): Maps expression type to corresponding annotation function.
- coerces_to (dict): Maps a data type to the set of data types that it can be coerced into.
Returns:
sqlglot.Expression: expression annotated with types
class
TypeAnnotator:
_Type = exp.DataType.Type


def _fixed_type(data_type):
    """Return an annotator that tags an expression with the constant `data_type`."""
    return lambda self, expr: self._annotate_with_type(expr, data_type)


def _widening_chain(*narrowest_to_widest):
    """
    Expand an ordered chain of types into a coercion table.

    Every type maps to the set of all types that follow it in the chain, so the last
    (widest) type maps to an empty set. Keys are inserted widest-first.
    """
    table = {}
    for position in reversed(range(len(narrowest_to_widest))):
        table[narrowest_to_widest[position]] = set(narrowest_to_widest[position + 1 :])
    return table


class TypeAnnotator:
    """
    Infers a type for the nodes of an expression tree and stores it on each node.

    Attributes:
        ANNOTATORS: Default mapping from expression class to a function called as
            `annotator(type_annotator, expression)` that annotates that kind of node.
        COERCES_TO: Default mapping from a data type to the set of data types it can be
            coerced into.
        TRAVERSABLES: Expression classes that `annotate` processes scope by scope.
    """

    ANNOTATORS = {
        # Every Unary / Binary subclass gets the generic handler; the explicit entries
        # that follow take precedence for classes that appear in both places.
        **dict.fromkeys(
            subclasses(exp.__name__, exp.Unary),
            lambda self, expr: self._annotate_unary(expr),
        ),
        **dict.fromkeys(
            subclasses(exp.__name__, exp.Binary),
            lambda self, expr: self._annotate_binary(expr),
        ),
        exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
        exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
        exp.Alias: lambda self, expr: self._annotate_unary(expr),
        exp.Between: _fixed_type(_Type.BOOLEAN),
        exp.In: _fixed_type(_Type.BOOLEAN),
        exp.Literal: lambda self, expr: self._annotate_literal(expr),
        exp.Boolean: _fixed_type(_Type.BOOLEAN),
        exp.Null: _fixed_type(_Type.NULL),
        exp.Anonymous: _fixed_type(_Type.UNKNOWN),
        exp.ApproxDistinct: _fixed_type(_Type.BIGINT),
        exp.Avg: _fixed_type(_Type.DOUBLE),
        exp.Min: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
        exp.Max: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
        exp.Sum: lambda self, expr: self._annotate_by_args(
            expr, "this", "expressions", promote=True
        ),
        exp.Ceil: _fixed_type(_Type.INT),
        exp.Count: _fixed_type(_Type.BIGINT),
        exp.CurrentDate: _fixed_type(_Type.DATE),
        exp.CurrentDatetime: _fixed_type(_Type.DATETIME),
        exp.CurrentTime: _fixed_type(_Type.TIMESTAMP),
        exp.CurrentTimestamp: _fixed_type(_Type.TIMESTAMP),
        exp.DateAdd: _fixed_type(_Type.DATE),
        exp.DateSub: _fixed_type(_Type.DATE),
        exp.DateDiff: _fixed_type(_Type.INT),
        exp.DatetimeAdd: _fixed_type(_Type.DATETIME),
        exp.DatetimeSub: _fixed_type(_Type.DATETIME),
        exp.DatetimeDiff: _fixed_type(_Type.INT),
        exp.Extract: _fixed_type(_Type.INT),
        exp.TimestampAdd: _fixed_type(_Type.TIMESTAMP),
        exp.TimestampSub: _fixed_type(_Type.TIMESTAMP),
        exp.TimestampDiff: _fixed_type(_Type.INT),
        exp.TimeAdd: _fixed_type(_Type.TIMESTAMP),
        exp.TimeSub: _fixed_type(_Type.TIMESTAMP),
        exp.TimeDiff: _fixed_type(_Type.INT),
        exp.DateStrToDate: _fixed_type(_Type.DATE),
        exp.DateToDateStr: _fixed_type(_Type.VARCHAR),
        exp.DateToDi: _fixed_type(_Type.INT),
        exp.Day: _fixed_type(_Type.TINYINT),
        exp.DiToDate: _fixed_type(_Type.DATE),
        exp.Exp: _fixed_type(_Type.DOUBLE),
        exp.Floor: _fixed_type(_Type.INT),
        exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
        exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
        exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
        exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
        exp.ConcatWs: _fixed_type(_Type.VARCHAR),
        exp.GroupConcat: _fixed_type(_Type.VARCHAR),
        exp.ArrayConcat: _fixed_type(_Type.VARCHAR),
        exp.Initcap: _fixed_type(_Type.VARCHAR),
        exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
        exp.Length: _fixed_type(_Type.BIGINT),
        exp.Levenshtein: _fixed_type(_Type.INT),
        exp.Ln: _fixed_type(_Type.DOUBLE),
        exp.Log: _fixed_type(_Type.DOUBLE),
        exp.Log2: _fixed_type(_Type.DOUBLE),
        exp.Log10: _fixed_type(_Type.DOUBLE),
        exp.Lower: _fixed_type(_Type.VARCHAR),
        exp.Month: _fixed_type(_Type.TINYINT),
        exp.Pow: _fixed_type(_Type.DOUBLE),
        exp.Quantile: _fixed_type(_Type.DOUBLE),
        exp.ApproxQuantile: _fixed_type(_Type.DOUBLE),
        exp.RegexpLike: _fixed_type(_Type.BOOLEAN),
        exp.Round: _fixed_type(_Type.DOUBLE),
        exp.SafeDivide: _fixed_type(_Type.DOUBLE),
        exp.Substring: _fixed_type(_Type.VARCHAR),
        exp.StrPosition: _fixed_type(_Type.INT),
        exp.StrToDate: _fixed_type(_Type.DATE),
        exp.StrToTime: _fixed_type(_Type.TIMESTAMP),
        exp.Sqrt: _fixed_type(_Type.DOUBLE),
        exp.Stddev: _fixed_type(_Type.DOUBLE),
        exp.StddevPop: _fixed_type(_Type.DOUBLE),
        exp.StddevSamp: _fixed_type(_Type.DOUBLE),
        exp.TimeToStr: _fixed_type(_Type.VARCHAR),
        exp.TimeToTimeStr: _fixed_type(_Type.VARCHAR),
        exp.TimeStrToDate: _fixed_type(_Type.DATE),
        exp.TimeStrToTime: _fixed_type(_Type.TIMESTAMP),
        exp.Trim: _fixed_type(_Type.VARCHAR),
        exp.TsOrDsToDateStr: _fixed_type(_Type.VARCHAR),
        exp.TsOrDsToDate: _fixed_type(_Type.DATE),
        exp.TsOrDiToDi: _fixed_type(_Type.INT),
        exp.UnixToStr: _fixed_type(_Type.VARCHAR),
        exp.UnixToTime: _fixed_type(_Type.TIMESTAMP),
        exp.UnixToTimeStr: _fixed_type(_Type.VARCHAR),
        exp.Upper: _fixed_type(_Type.VARCHAR),
        exp.Variance: _fixed_type(_Type.DOUBLE),
        exp.VariancePop: _fixed_type(_Type.DOUBLE),
        exp.Week: _fixed_type(_Type.TINYINT),
        exp.Year: _fixed_type(_Type.TINYINT),
    }

    # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
    COERCES_TO = {
        # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
        **_widening_chain(_Type.CHAR, _Type.NCHAR, _Type.VARCHAR, _Type.NVARCHAR, _Type.TEXT),
        # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
        **_widening_chain(
            _Type.TINYINT,
            _Type.SMALLINT,
            _Type.INT,
            _Type.BIGINT,
            _Type.DECIMAL,
            _Type.FLOAT,
            _Type.DOUBLE,
        ),
        # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
        **_widening_chain(
            _Type.DATE,
            _Type.DATETIME,
            _Type.TIMESTAMP,
            _Type.TIMESTAMPTZ,
            _Type.TIMESTAMPLTZ,
        ),
    }

    TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)

    def __init__(self, schema=None, annotators=None, coerces_to=None):
        """
        Store the schema and lookup tables used during annotation.

        A falsy `annotators` or `coerces_to` selects the class-level default table.
        """
        self.schema = schema
        self.annotators = annotators if annotators else self.ANNOTATORS
        self.coerces_to = coerces_to if coerces_to else self.COERCES_TO

    def annotate(self, expression):
        """
        Annotate `expression` and the nodes beneath it with inferred types.

        When `expression` is one of `TRAVERSABLES`, every scope produced by
        `traverse_scope` is handled in turn: columns that name a table are typed first
        (from the schema for table sources, otherwise from the matching projection of
        the source scope), then the scope's own expression is annotated.

        Returns:
            The result of `_maybe_annotate(expression)`: the expression itself, or None
            when it is not an `exp.Expression`.
        """
        if not isinstance(expression, self.TRAVERSABLES):
            return self._maybe_annotate(expression)

        for scope in traverse_scope(expression):
            # Per source name: the expressions that source exposes, keyed by output name
            projections = {}
            for source_name, source in scope.sources.items():
                if not isinstance(source, Scope):
                    continue

                node = source.expression
                if not isinstance(node, exp.UDTF):
                    projections[source_name] = {
                        projection.alias_or_name: projection for projection in node.selects
                    }
                    continue

                if not isinstance(node, exp.Lateral):
                    values = node.expressions[0].expressions
                elif isinstance(node.this, exp.Explode):
                    values = [node.this.this]
                else:
                    values = []

                if values:
                    projections[source_name] = dict(zip(node.alias_column_names, values))

            # Column references of the current scope are typed before anything else
            for column in scope.columns:
                table_name = column.table
                if not table_name:
                    continue

                source = scope.sources.get(table_name)
                if isinstance(source, exp.Table):
                    column.type = self.schema.get_column_type(source, column)
                elif source and column.name in projections.get(table_name, {}):
                    column.type = projections[table_name][column.name].type

            # Whatever is still untyped in the scope is handled by the annotators
            self._maybe_annotate(scope.expression)

        # Covers any part of the tree that the scope traversal did not annotate
        return self._maybe_annotate(expression)

    def _maybe_annotate(self, expression):
        """
        Annotate `expression` unless it already carries a type.

        Returns None for anything that is not an `exp.Expression`. Expression classes
        without a registered annotator are typed as UNKNOWN.
        """
        if not isinstance(expression, exp.Expression):
            return None

        if expression.type:
            return expression  # Already typed; nothing left to infer

        handler = self.annotators.get(expression.__class__)
        if handler:
            return handler(self, expression)
        return self._annotate_with_type(expression, _Type.UNKNOWN)

    def _annotate_args(self, expression):
        """Run `_maybe_annotate` on every value held in `expression.args`; return `expression`."""
        for arg_value in expression.args.values():
            for item in ensure_collection(arg_value):
                self._maybe_annotate(item)

        return expression

    def _maybe_coerce(self, type1, type2):
        """
        Pick the type that results from combining `type1` with `type2`.

        `exp.DataType` arguments are reduced to their `this` value first. NULL wins over
        everything, then UNKNOWN; otherwise `type2` is returned when `type1` can be
        coerced into it according to `self.coerces_to`, and `type1` is returned if not.
        """
        left = type1.this if isinstance(type1, exp.DataType) else type1
        right = type2.this if isinstance(type2, exp.DataType) else type2
        operands = (left, right)

        # NULL / UNKNOWN propagate upwards, NULL taking priority
        for sentinel in (_Type.NULL, _Type.UNKNOWN):
            if sentinel in operands:
                return sentinel

        if right in self.coerces_to.get(left, {}):
            return right
        return left

    def _annotate_binary(self, expression):
        """
        Annotate a binary expression after annotating its arguments.

        AND / OR yield NULL when both sides are NULL, NULLABLE<BOOLEAN> when exactly one
        side is NULL, and BOOLEAN otherwise. Other conditions and predicates are BOOLEAN;
        anything else gets the coerced type of its two operands.
        """
        self._annotate_args(expression)

        lhs = expression.left.type.this
        rhs = expression.right.type.this

        if not isinstance(expression, (exp.And, exp.Or)):
            if isinstance(expression, (exp.Condition, exp.Predicate)):
                inferred = _Type.BOOLEAN
            else:
                inferred = self._maybe_coerce(lhs, rhs)
        elif lhs == _Type.NULL and rhs == _Type.NULL:
            inferred = _Type.NULL
        elif _Type.NULL in (lhs, rhs):
            inferred = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN"))
        else:
            inferred = _Type.BOOLEAN

        expression.type = inferred
        return expression

    def _annotate_unary(self, expression):
        """
        Annotate a unary expression after annotating its arguments.

        Parentheses and non-condition expressions take the type of their operand; every
        other condition is BOOLEAN.
        """
        self._annotate_args(expression)

        if isinstance(expression, exp.Paren) or not isinstance(expression, exp.Condition):
            expression.type = expression.this.type
        else:
            expression.type = _Type.BOOLEAN

        return expression

    def _annotate_literal(self, expression):
        """Type a literal as VARCHAR (string), INT (integer) or DOUBLE (anything else)."""
        if expression.is_string:
            inferred = _Type.VARCHAR
        else:
            inferred = _Type.INT if expression.is_int else _Type.DOUBLE

        expression.type = inferred
        return expression

    def _annotate_with_type(self, expression, target_type):
        """Assign `target_type` to `expression`, then annotate its arguments."""
        expression.type = target_type
        return self._annotate_args(expression)

    def _annotate_by_args(self, expression, *args, promote=False):
        """
        Derive the type of `expression` from the arguments named in `args`.

        The named arguments are gathered in order (falsy entries skipped) and their types
        are folded together with `_maybe_coerce`. With no usable argument the type is
        UNKNOWN. When `promote` is set, integer results become BIGINT and floating point
        results become DOUBLE.
        """
        self._annotate_args(expression)

        operands = [
            operand
            for arg_name in args
            for operand in ensure_list(expression.args.get(arg_name))
            if operand
        ]

        combined = None
        for operand in operands:
            combined = self._maybe_coerce(combined or operand.type, operand.type)

        expression.type = combined or _Type.UNKNOWN

        if promote:
            current = expression.type.this
            if current in exp.DataType.INTEGER_TYPES:
                expression.type = _Type.BIGINT
            elif current in exp.DataType.FLOAT_TYPES:
                expression.type = _Type.DOUBLE

        return expression
def annotate(self, expression):
    """
    Annotate ``expression`` and the expressions of its scopes with types.

    When ``expression`` is an instance of one of ``self.TRAVERSABLES``, every scope
    yielded by ``traverse_scope`` is processed in turn: columns qualified with a table
    name receive their type either from ``self.schema`` (the source is an ``exp.Table``)
    or from the matching projection of a child scope, after which the scope's own
    expression is annotated.

    Args:
        expression: The expression to annotate.

    Returns:
        Whatever ``self._maybe_annotate(expression)`` returns.
    """
    if not isinstance(expression, self.TRAVERSABLES):
        # Nothing to traverse, so the expression is annotated directly
        return self._maybe_annotate(expression)

    for scope in traverse_scope(expression):
        # Maps a source name to {output column name: expression producing that column}
        projections = {}
        for source_name, child in scope.sources.items():
            if not isinstance(child, Scope):
                continue

            child_expr = child.expression
            if not isinstance(child_expr, exp.UDTF):
                projections[source_name] = {
                    projection.alias_or_name: projection for projection in child_expr.selects
                }
                continue

            if not isinstance(child_expr, exp.Lateral):
                produced = child_expr.expressions[0].expressions
            elif isinstance(child_expr.this, exp.Explode):
                produced = [child_expr.this.this]
            else:
                produced = []

            if produced:
                projections[source_name] = dict(zip(child_expr.alias_column_names, produced))

        # Column references of the current scope are typed first ...
        for column in scope.columns:
            table_name = column.table
            if not table_name:
                continue

            origin = scope.sources.get(table_name)
            if isinstance(origin, exp.Table):
                column.type = self.schema.get_column_type(origin, column)
            elif origin and table_name in projections and column.name in projections[table_name]:
                column.type = projections[table_name][column.name].type

        # ... then the remaining expressions in the scope are (possibly) annotated
        self._maybe_annotate(scope.expression)

    # This takes care of non-traversable expressions
    return self._maybe_annotate(expression)