sqlglot.optimizer.annotate_types
from __future__ import annotations

import typing as t

from sqlglot import exp
from sqlglot._typing import E
from sqlglot.helper import ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    B = t.TypeVar("B", bound=exp.Binary)


def annotate_types(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> E:
    """
    Infers the types of an expression, annotating its AST accordingly.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression: Expression to annotate.
        schema: Database schema.
        annotators: Maps expression type to corresponding annotation function.
        coerces_to: Maps a data type to the set of data types it can be coerced into.

    Returns:
        The expression annotated with types.
    """

    schema = ensure_schema(schema)

    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)


def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
    """Builds an annotator that always assigns `data_type` to the expression it receives."""
    # A factory is used so that each returned lambda captures its own `data_type`.
    return lambda self, e: self._annotate_with_type(e, data_type)


class _TypeAnnotator(type):
    """
    Metaclass that fills in `COERCES_TO` when a class using it is created.

    For every precedence tuple below, each type is mapped to the set of types that
    appear before it in the tuple, i.e. the higher-precedence types of the same family.
    """

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)

        # Highest-to-lowest type precedence, as specified in Spark's docs (ANSI):
        # https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
        text_precedence = (
            exp.DataType.Type.TEXT,
            exp.DataType.Type.NVARCHAR,
            exp.DataType.Type.VARCHAR,
            exp.DataType.Type.NCHAR,
            exp.DataType.Type.CHAR,
        )
        numeric_precedence = (
            exp.DataType.Type.DOUBLE,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.INT,
            exp.DataType.Type.SMALLINT,
            exp.DataType.Type.TINYINT,
        )
        timelike_precedence = (
            exp.DataType.Type.TIMESTAMPLTZ,
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMP,
            exp.DataType.Type.DATETIME,
            exp.DataType.Type.DATE,
        )

        for type_precedence in (text_precedence, numeric_precedence, timelike_precedence):
            # Accumulates the types seen so far, all of which outrank the current one
            coerces_to = set()
            for data_type in type_precedence:
                klass.COERCES_TO[data_type] = coerces_to.copy()
                coerces_to |= {data_type}

        return klass


class TypeAnnotator(metaclass=_TypeAnnotator):
    """
    Annotates the nodes of an expression tree with their inferred types.

    Args:
        schema: Schema used to look up the types of table columns.
        annotators: Maps expression type to corresponding annotation function.
            Falls back to `ANNOTATORS` when not provided (or empty).
        coerces_to: Maps a data type to the set of data types it can be coerced into.
            Falls back to `COERCES_TO` when not provided (or empty).
    """

    # Expression classes whose type is fixed, grouped by that type
    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Count,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateAdd,
            exp.DateFromParts,
            exp.DateStrToDate,
            exp.DateSub,
            exp.DateTrunc,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Log2,
            exp.Log10,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DateDiff,
            exp.DatetimeDiff,
            exp.Extract,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Floor,
            exp.Levenshtein,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
        },
        exp.DataType.Type.VARCHAR: {
            exp.ArrayConcat,
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.SafeConcat,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    # Maps an expression class to the function that annotates its instances.
    # Entries are merged in order, so later ones win: fixed-type entries override the
    # generic unary / binary handlers, and the explicit entries override both.
    ANNOTATORS: t.Dict = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
    }

    # Types that _maybe_coerce returns as given instead of reducing them to their Type value
    NESTED_TYPES = {
        exp.DataType.Type.ARRAY,
    }

    # Specifies what types a given type can be coerced into (autofilled)
    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}

    def __init__(
        self,
        schema: Schema,
        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
    ) -> None:
        self.schema = schema
        self.annotators = annotators or self.ANNOTATORS
        self.coerces_to = coerces_to or self.COERCES_TO

    def annotate(self, expression: E) -> E:
        """
        Annotates `expression` and its sub-expressions with types.

        For every scope produced by `traverse_scope`, the qualified column references are
        typed first (from the schema or from the projections of the scope they come from),
        then the scope's own expression is annotated.

        Returns:
            The same expression, annotated.
        """
        for scope in traverse_scope(expression):
            # Maps source name -> {output column name -> expression producing it}
            selects = {}
            for name, source in scope.sources.items():
                if not isinstance(source, Scope):
                    continue
                if isinstance(source.expression, exp.UDTF):
                    values = []

                    if isinstance(source.expression, exp.Lateral):
                        if isinstance(source.expression.this, exp.Explode):
                            values = [source.expression.this.this]
                    else:
                        values = source.expression.expressions[0].expressions

                    if not values:
                        continue

                    # Pair each column alias with the value at the same position
                    selects[name] = {
                        alias: column
                        for alias, column in zip(
                            source.expression.alias_column_names,
                            values,
                        )
                    }
                else:
                    selects[name] = {
                        select.alias_or_name: select for select in source.expression.selects
                    }

            # First annotate the current scope's column references
            for col in scope.columns:
                if not col.table:
                    continue

                source = scope.sources.get(col.table)
                if isinstance(source, exp.Table):
                    col.type = self.schema.get_column_type(source, col)
                elif source and col.table in selects and col.name in selects[col.table]:
                    col.type = selects[col.table][col.name].type

            # Then (possibly) annotate the remaining expressions in the scope
            self._maybe_annotate(scope.expression)

        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions

    def _maybe_annotate(self, expression: E) -> E:
        """
        Annotates `expression` unless it already has a type.

        Expression classes without a registered annotator are typed as UNKNOWN.
        """
        if expression.type:
            return expression  # We've already inferred the expression's type

        annotator = self.annotators.get(expression.__class__)

        return (
            annotator(self, expression)
            if annotator
            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
        )

    def _annotate_args(self, expression: E) -> E:
        """Annotates every child expression yielded by `expression.iter_expressions()`."""
        for _, value in expression.iter_expressions():
            self._maybe_annotate(value)

        return expression

    def _maybe_coerce(
        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
    ) -> exp.DataType | exp.DataType.Type:
        """
        Picks the type that results from combining `type1` and `type2`.

        NULL wins over everything, then UNKNOWN. A nested type is returned as given
        (`type1` is checked first). Otherwise `type2` is returned if `type1` can be
        coerced into it according to `self.coerces_to`, else `type1`.
        """
        type1_value = type1.this if isinstance(type1, exp.DataType) else type1
        type2_value = type2.this if isinstance(type2, exp.DataType) else type2

        # We propagate the NULL / UNKNOWN types upwards if found
        if exp.DataType.Type.NULL in (type1_value, type2_value):
            return exp.DataType.Type.NULL
        if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
            return exp.DataType.Type.UNKNOWN

        if type1_value in self.NESTED_TYPES:
            return type1
        if type2_value in self.NESTED_TYPES:
            return type2

        return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value  # type: ignore

    # Note: the following "no_type_check" decorators were added because mypy was yelling due
    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
    # This is a known mypy issue: https://github.com/python/mypy/issues/3004

    @t.no_type_check
    def _annotate_binary(self, expression: B) -> B:
        """
        Annotates a binary expression based on its operands.

        Connectors are NULL if both sides are NULL, NULLABLE<BOOLEAN> if exactly one side
        is NULL and BOOLEAN otherwise. Other predicates are BOOLEAN. Anything else gets
        the type obtained by coercing the left and right operand types.
        """
        self._annotate_args(expression)

        left_type = expression.left.type.this
        right_type = expression.right.type.this

        if isinstance(expression, exp.Connector):
            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
                expression.type = exp.DataType.Type.NULL
            elif exp.DataType.Type.NULL in (left_type, right_type):
                expression.type = exp.DataType.build(
                    "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
                )
            else:
                expression.type = exp.DataType.Type.BOOLEAN
        elif isinstance(expression, exp.Predicate):
            expression.type = exp.DataType.Type.BOOLEAN
        else:
            expression.type = self._maybe_coerce(left_type, right_type)

        return expression

    @t.no_type_check
    def _annotate_unary(self, expression: E) -> E:
        """
        Annotates a unary expression or an alias.

        A logical NOT is BOOLEAN; every other node handled here takes the type of its
        operand (`expression.this`).
        """
        self._annotate_args(expression)

        # Only NOT produces a boolean. Testing for exp.Condition here (with Paren as the
        # sole exception) would also type arithmetic / bitwise unaries such as Neg and
        # BitwiseNot as BOOLEAN instead of propagating their operand's type.
        if isinstance(expression, exp.Not):
            expression.type = exp.DataType.Type.BOOLEAN
        else:
            expression.type = expression.this.type

        return expression

    @t.no_type_check
    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
        """Types a literal as VARCHAR if it's a string, INT if it's an integer, else DOUBLE."""
        if expression.is_string:
            expression.type = exp.DataType.Type.VARCHAR
        elif expression.is_int:
            expression.type = exp.DataType.Type.INT
        else:
            expression.type = exp.DataType.Type.DOUBLE

        return expression

    @t.no_type_check
    def _annotate_with_type(
        self, expression: E, target_type: exp.DataType | exp.DataType.Type
    ) -> E:
        """Assigns `target_type` to `expression` and then annotates its children."""
        expression.type = target_type
        return self._annotate_args(expression)

    @t.no_type_check
    def _annotate_by_args(
        self, expression: E, *args: str, promote: bool = False, array: bool = False
    ) -> E:
        """
        Derives the type of `expression` from the types of some of its arguments.

        Args:
            expression: The expression to annotate.
            *args: Names of the arguments (keys of `expression.args`) to take into account.
            promote: Whether to widen the result: integer types become BIGINT and
                float types become DOUBLE.
            array: Whether to wrap the resulting type in an ARRAY type.

        Returns:
            The annotated expression. Its type is UNKNOWN if none of the named
            arguments is set.
        """
        self._annotate_args(expression)

        # Flatten the named arguments into a single list, dropping the unset ones
        expressions: t.List[exp.Expression] = []
        for arg in args:
            arg_expr = expression.args.get(arg)
            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)

        # Fold the argument types pairwise into a single type
        last_datatype = None
        for expr in expressions:
            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)

        expression.type = last_datatype or exp.DataType.Type.UNKNOWN

        if promote:
            if expression.type.this in exp.DataType.INTEGER_TYPES:
                expression.type = exp.DataType.Type.BIGINT
            elif expression.type.this in exp.DataType.FLOAT_TYPES:
                expression.type = exp.DataType.Type.DOUBLE

        if array:
            expression.type = exp.DataType(
                this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
            )

        return expression
def
annotate_types( expression: ~E, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, annotators: Optional[Dict[Type[~E], Callable[[TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None) -> ~E:
def annotate_types(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> E:
    """
    Infers the types of an expression, annotating its AST accordingly.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression: Expression to annotate.
        schema: Database schema; it is passed through `ensure_schema` before use.
        annotators: Maps expression type to corresponding annotation function.
        coerces_to: Maps a data type to the set of data types it can be coerced into.

    Returns:
        The expression annotated with types.
    """
    # Normalize the schema argument, then delegate the work to a TypeAnnotator
    type_annotator = TypeAnnotator(ensure_schema(schema), annotators, coerces_to)
    return type_annotator.annotate(expression)
Infers the types of an expression, annotating its AST accordingly.
Example:
>>> import sqlglot >>> schema = {"y": {"cola": "SMALLINT"}} >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x" >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema) >>> annotated_expr.expressions[0].type.this # Get the type of "x.cola + 2.5 AS cola" <Type.DOUBLE: 'DOUBLE'>
Arguments:
- expression: Expression to annotate.
- schema: Database schema.
- annotators: Maps expression type to corresponding annotation function.
- coerces_to: Maps a data type to the set of data types that it can be coerced into.
Returns:
The expression annotated with types.
class
TypeAnnotator:
class TypeAnnotator(metaclass=_TypeAnnotator):
    """
    Annotates the nodes of an expression tree with their inferred types.

    Args:
        schema: Schema used to look up the types of table columns.
        annotators: Maps expression type to corresponding annotation function.
            Falls back to `ANNOTATORS` when not provided (or empty).
        coerces_to: Maps a data type to the set of data types it can be coerced into.
            Falls back to `COERCES_TO` when not provided (or empty).
    """

    # Expression classes whose type is fixed, grouped by that type
    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Count,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateAdd,
            exp.DateFromParts,
            exp.DateStrToDate,
            exp.DateSub,
            exp.DateTrunc,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Log2,
            exp.Log10,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DateDiff,
            exp.DatetimeDiff,
            exp.Extract,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Floor,
            exp.Levenshtein,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
        },
        exp.DataType.Type.VARCHAR: {
            exp.ArrayConcat,
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.SafeConcat,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    # Maps an expression class to the function that annotates its instances.
    # Entries are merged in order, so later ones win: fixed-type entries override the
    # generic unary / binary handlers, and the explicit entries override both.
    ANNOTATORS: t.Dict = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
    }

    # Types that _maybe_coerce returns as given instead of reducing them to their Type value
    NESTED_TYPES = {
        exp.DataType.Type.ARRAY,
    }

    # Specifies what types a given type can be coerced into (autofilled)
    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}

    def __init__(
        self,
        schema: Schema,
        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
    ) -> None:
        self.schema = schema
        self.annotators = annotators or self.ANNOTATORS
        self.coerces_to = coerces_to or self.COERCES_TO

    def annotate(self, expression: E) -> E:
        """
        Annotates `expression` and its sub-expressions with types.

        For every scope produced by `traverse_scope`, the qualified column references are
        typed first (from the schema or from the projections of the scope they come from),
        then the scope's own expression is annotated.

        Returns:
            The same expression, annotated.
        """
        for scope in traverse_scope(expression):
            # Maps source name -> {output column name -> expression producing it}
            selects = {}
            for name, source in scope.sources.items():
                if not isinstance(source, Scope):
                    continue
                if isinstance(source.expression, exp.UDTF):
                    values = []

                    if isinstance(source.expression, exp.Lateral):
                        if isinstance(source.expression.this, exp.Explode):
                            values = [source.expression.this.this]
                    else:
                        values = source.expression.expressions[0].expressions

                    if not values:
                        continue

                    # Pair each column alias with the value at the same position
                    selects[name] = {
                        alias: column
                        for alias, column in zip(
                            source.expression.alias_column_names,
                            values,
                        )
                    }
                else:
                    selects[name] = {
                        select.alias_or_name: select for select in source.expression.selects
                    }

            # First annotate the current scope's column references
            for col in scope.columns:
                if not col.table:
                    continue

                source = scope.sources.get(col.table)
                if isinstance(source, exp.Table):
                    col.type = self.schema.get_column_type(source, col)
                elif source and col.table in selects and col.name in selects[col.table]:
                    col.type = selects[col.table][col.name].type

            # Then (possibly) annotate the remaining expressions in the scope
            self._maybe_annotate(scope.expression)

        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions

    def _maybe_annotate(self, expression: E) -> E:
        """
        Annotates `expression` unless it already has a type.

        Expression classes without a registered annotator are typed as UNKNOWN.
        """
        if expression.type:
            return expression  # We've already inferred the expression's type

        annotator = self.annotators.get(expression.__class__)

        return (
            annotator(self, expression)
            if annotator
            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
        )

    def _annotate_args(self, expression: E) -> E:
        """Annotates every child expression yielded by `expression.iter_expressions()`."""
        for _, value in expression.iter_expressions():
            self._maybe_annotate(value)

        return expression

    def _maybe_coerce(
        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
    ) -> exp.DataType | exp.DataType.Type:
        """
        Picks the type that results from combining `type1` and `type2`.

        NULL wins over everything, then UNKNOWN. A nested type is returned as given
        (`type1` is checked first). Otherwise `type2` is returned if `type1` can be
        coerced into it according to `self.coerces_to`, else `type1`.
        """
        type1_value = type1.this if isinstance(type1, exp.DataType) else type1
        type2_value = type2.this if isinstance(type2, exp.DataType) else type2

        # We propagate the NULL / UNKNOWN types upwards if found
        if exp.DataType.Type.NULL in (type1_value, type2_value):
            return exp.DataType.Type.NULL
        if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
            return exp.DataType.Type.UNKNOWN

        if type1_value in self.NESTED_TYPES:
            return type1
        if type2_value in self.NESTED_TYPES:
            return type2

        return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value  # type: ignore

    # Note: the following "no_type_check" decorators were added because mypy was yelling due
    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
    # This is a known mypy issue: https://github.com/python/mypy/issues/3004

    @t.no_type_check
    def _annotate_binary(self, expression: B) -> B:
        """
        Annotates a binary expression based on its operands.

        Connectors are NULL if both sides are NULL, NULLABLE<BOOLEAN> if exactly one side
        is NULL and BOOLEAN otherwise. Other predicates are BOOLEAN. Anything else gets
        the type obtained by coercing the left and right operand types.
        """
        self._annotate_args(expression)

        left_type = expression.left.type.this
        right_type = expression.right.type.this

        if isinstance(expression, exp.Connector):
            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
                expression.type = exp.DataType.Type.NULL
            elif exp.DataType.Type.NULL in (left_type, right_type):
                expression.type = exp.DataType.build(
                    "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
                )
            else:
                expression.type = exp.DataType.Type.BOOLEAN
        elif isinstance(expression, exp.Predicate):
            expression.type = exp.DataType.Type.BOOLEAN
        else:
            expression.type = self._maybe_coerce(left_type, right_type)

        return expression

    @t.no_type_check
    def _annotate_unary(self, expression: E) -> E:
        """
        Annotates a unary expression or an alias.

        A logical NOT is BOOLEAN; every other node handled here takes the type of its
        operand (`expression.this`).
        """
        self._annotate_args(expression)

        # Only NOT produces a boolean. Testing for exp.Condition here (with Paren as the
        # sole exception) would also type arithmetic / bitwise unaries such as Neg and
        # BitwiseNot as BOOLEAN instead of propagating their operand's type.
        if isinstance(expression, exp.Not):
            expression.type = exp.DataType.Type.BOOLEAN
        else:
            expression.type = expression.this.type

        return expression

    @t.no_type_check
    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
        """Types a literal as VARCHAR if it's a string, INT if it's an integer, else DOUBLE."""
        if expression.is_string:
            expression.type = exp.DataType.Type.VARCHAR
        elif expression.is_int:
            expression.type = exp.DataType.Type.INT
        else:
            expression.type = exp.DataType.Type.DOUBLE

        return expression

    @t.no_type_check
    def _annotate_with_type(
        self, expression: E, target_type: exp.DataType | exp.DataType.Type
    ) -> E:
        """Assigns `target_type` to `expression` and then annotates its children."""
        expression.type = target_type
        return self._annotate_args(expression)

    @t.no_type_check
    def _annotate_by_args(
        self, expression: E, *args: str, promote: bool = False, array: bool = False
    ) -> E:
        """
        Derives the type of `expression` from the types of some of its arguments.

        Args:
            expression: The expression to annotate.
            *args: Names of the arguments (keys of `expression.args`) to take into account.
            promote: Whether to widen the result: integer types become BIGINT and
                float types become DOUBLE.
            array: Whether to wrap the resulting type in an ARRAY type.

        Returns:
            The annotated expression. Its type is UNKNOWN if none of the named
            arguments is set.
        """
        self._annotate_args(expression)

        # Flatten the named arguments into a single list, dropping the unset ones
        expressions: t.List[exp.Expression] = []
        for arg in args:
            arg_expr = expression.args.get(arg)
            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)

        # Fold the argument types pairwise into a single type
        last_datatype = None
        for expr in expressions:
            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)

        expression.type = last_datatype or exp.DataType.Type.UNKNOWN

        if promote:
            if expression.type.this in exp.DataType.INTEGER_TYPES:
                expression.type = exp.DataType.Type.BIGINT
            elif expression.type.this in exp.DataType.FLOAT_TYPES:
                expression.type = exp.DataType.Type.DOUBLE

        if array:
            expression.type = exp.DataType(
                this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
            )

        return expression
TypeAnnotator( schema: sqlglot.schema.Schema, annotators: Optional[Dict[Type[~E], Callable[[TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None)
def __init__(
    self,
    schema: Schema,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> None:
    """
    Stores the schema and the annotation / coercion tables to use.

    Args:
        schema: Schema kept on the instance as `self.schema`.
        annotators: Maps expression type to corresponding annotation function.
            A falsy value (None or an empty mapping) selects `self.ANNOTATORS`.
        coerces_to: Coercion table. A falsy value (None or an empty mapping)
            selects `self.COERCES_TO`.
    """
    self.schema = schema

    # Truthiness (not an `is None` test) decides whether the class-level default is used
    self.annotators = annotators if annotators else self.ANNOTATORS
    self.coerces_to = coerces_to if coerces_to else self.COERCES_TO
TYPE_TO_EXPRESSIONS: Dict[sqlglot.expressions.DataType.Type, Set[Type[sqlglot.expressions.Expression]]] =
{<Type.BIGINT: 'BIGINT'>: {<class 'sqlglot.expressions.ArraySize'>, <class 'sqlglot.expressions.ApproxDistinct'>, <class 'sqlglot.expressions.Length'>, <class 'sqlglot.expressions.Count'>}, <Type.BOOLEAN: 'BOOLEAN'>: {<class 'sqlglot.expressions.Between'>, <class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.Boolean'>, <class 'sqlglot.expressions.RegexpLike'>}, <Type.DATE: 'DATE'>: {<class 'sqlglot.expressions.DiToDate'>, <class 'sqlglot.expressions.Date'>, <class 'sqlglot.expressions.StrToDate'>, <class 'sqlglot.expressions.CurrentDate'>, <class 'sqlglot.expressions.DateTrunc'>, <class 'sqlglot.expressions.TimeStrToDate'>, <class 'sqlglot.expressions.DateSub'>, <class 'sqlglot.expressions.TsOrDsToDate'>, <class 'sqlglot.expressions.DateStrToDate'>, <class 'sqlglot.expressions.DateAdd'>, <class 'sqlglot.expressions.DateFromParts'>}, <Type.DATETIME: 'DATETIME'>: {<class 'sqlglot.expressions.CurrentDatetime'>, <class 'sqlglot.expressions.DatetimeAdd'>, <class 'sqlglot.expressions.DatetimeSub'>}, <Type.DOUBLE: 'DOUBLE'>: {<class 'sqlglot.expressions.Pow'>, <class 'sqlglot.expressions.Log2'>, <class 'sqlglot.expressions.Stddev'>, <class 'sqlglot.expressions.Log'>, <class 'sqlglot.expressions.Sqrt'>, <class 'sqlglot.expressions.Avg'>, <class 'sqlglot.expressions.StddevPop'>, <class 'sqlglot.expressions.VariancePop'>, <class 'sqlglot.expressions.SafeDivide'>, <class 'sqlglot.expressions.ApproxQuantile'>, <class 'sqlglot.expressions.Quantile'>, <class 'sqlglot.expressions.Exp'>, <class 'sqlglot.expressions.StddevSamp'>, <class 'sqlglot.expressions.Log10'>, <class 'sqlglot.expressions.Variance'>, <class 'sqlglot.expressions.Ln'>, <class 'sqlglot.expressions.Round'>}, <Type.INT: 'INT'>: {<class 'sqlglot.expressions.Extract'>, <class 'sqlglot.expressions.Levenshtein'>, <class 'sqlglot.expressions.DatetimeDiff'>, <class 'sqlglot.expressions.Ceil'>, <class 'sqlglot.expressions.TimeDiff'>, <class 'sqlglot.expressions.DateToDi'>, <class 'sqlglot.expressions.Floor'>, 
<class 'sqlglot.expressions.DateDiff'>, <class 'sqlglot.expressions.StrPosition'>, <class 'sqlglot.expressions.TimestampDiff'>, <class 'sqlglot.expressions.TsOrDiToDi'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<class 'sqlglot.expressions.UnixToTime'>, <class 'sqlglot.expressions.CurrentTimestamp'>, <class 'sqlglot.expressions.CurrentTime'>, <class 'sqlglot.expressions.TimeSub'>, <class 'sqlglot.expressions.StrToTime'>, <class 'sqlglot.expressions.TimeStrToTime'>, <class 'sqlglot.expressions.TimeAdd'>, <class 'sqlglot.expressions.TimestampAdd'>, <class 'sqlglot.expressions.TimestampSub'>}, <Type.TINYINT: 'TINYINT'>: {<class 'sqlglot.expressions.Day'>, <class 'sqlglot.expressions.Month'>, <class 'sqlglot.expressions.Year'>, <class 'sqlglot.expressions.Week'>}, <Type.VARCHAR: 'VARCHAR'>: {<class 'sqlglot.expressions.GroupConcat'>, <class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.Trim'>, <class 'sqlglot.expressions.UnixToTimeStr'>, <class 'sqlglot.expressions.UnixToStr'>, <class 'sqlglot.expressions.ArrayConcat'>, <class 'sqlglot.expressions.Lower'>, <class 'sqlglot.expressions.DateToDateStr'>, <class 'sqlglot.expressions.TsOrDsToDateStr'>, <class 'sqlglot.expressions.TimeToTimeStr'>, <class 'sqlglot.expressions.Initcap'>, <class 'sqlglot.expressions.Upper'>, <class 'sqlglot.expressions.ConcatWs'>, <class 'sqlglot.expressions.TimeToStr'>, <class 'sqlglot.expressions.Substring'>, <class 'sqlglot.expressions.SafeConcat'>}}
ANNOTATORS: Dict =
{<class 'sqlglot.expressions.Alias'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseNot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Neg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Not'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Paren'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Unary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Add'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.And'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContained'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayOverlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Binary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseAnd'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseLeftShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseOr'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseRightShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseXor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Collate'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Connector'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.DPipe'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Distance'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Div'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Dot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 
'sqlglot.expressions.EQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Escape'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Glob'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILike'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.IntDiv'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Is'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Kwarg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Like'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mod'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mul'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeEQ'>: <function 
TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeNEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Or'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Overlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Pow'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.RegexpLike'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDPipe'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.SimilarTo'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Slice'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Sub'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Xor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArraySize'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxDistinct'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Length'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Count'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Between'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.In'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Boolean'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DiToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Date'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 
'sqlglot.expressions.DateTrunc'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateFromParts'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDatetime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log2'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Stddev'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Sqrt'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Avg'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevPop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.VariancePop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDivide'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxQuantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Quantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Exp'>: 
<function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevSamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log10'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Variance'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ln'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Round'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Extract'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Levenshtein'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ceil'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Floor'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrPosition'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDiToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTimestamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeSub'>: <function 
_annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Day'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Month'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Year'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Week'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.GroupConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Concat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Trim'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArrayConcat'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Lower'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Initcap'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 
'sqlglot.expressions.Upper'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ConcatWs'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Substring'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Anonymous'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Array'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.ArrayAgg'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Cast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Case'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Coalesce'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DataType'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Distinct'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Filter'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.If'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Interval'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Least'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Literal'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Map'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Max'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Min'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Null'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Sum'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.TryCast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.VarMap'>: <function TypeAnnotator.<lambda>>}
COERCES_TO: Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]] =
{<Type.TEXT: 'TEXT'>: set(), <Type.NVARCHAR: 'NVARCHAR'>: {<Type.TEXT: 'TEXT'>}, <Type.VARCHAR: 'VARCHAR'>: {<Type.NVARCHAR: 'NVARCHAR'>, <Type.TEXT: 'TEXT'>}, <Type.NCHAR: 'NCHAR'>: {<Type.NVARCHAR: 'NVARCHAR'>, <Type.TEXT: 'TEXT'>, <Type.VARCHAR: 'VARCHAR'>}, <Type.CHAR: 'CHAR'>: {<Type.NVARCHAR: 'NVARCHAR'>, <Type.TEXT: 'TEXT'>, <Type.NCHAR: 'NCHAR'>, <Type.VARCHAR: 'VARCHAR'>}, <Type.DOUBLE: 'DOUBLE'>: set(), <Type.FLOAT: 'FLOAT'>: {<Type.DOUBLE: 'DOUBLE'>}, <Type.DECIMAL: 'DECIMAL'>: {<Type.FLOAT: 'FLOAT'>, <Type.DOUBLE: 'DOUBLE'>}, <Type.BIGINT: 'BIGINT'>: {<Type.FLOAT: 'FLOAT'>, <Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>}, <Type.INT: 'INT'>: {<Type.FLOAT: 'FLOAT'>, <Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>, <Type.BIGINT: 'BIGINT'>}, <Type.SMALLINT: 'SMALLINT'>: {<Type.FLOAT: 'FLOAT'>, <Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>, <Type.INT: 'INT'>, <Type.BIGINT: 'BIGINT'>}, <Type.TINYINT: 'TINYINT'>: {<Type.FLOAT: 'FLOAT'>, <Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>, <Type.INT: 'INT'>, <Type.BIGINT: 'BIGINT'>, <Type.SMALLINT: 'SMALLINT'>}, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>: set(), <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>: {<Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.DATETIME: 'DATETIME'>: {<Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.TIMESTAMP: 'TIMESTAMP'>}, <Type.DATE: 'DATE'>: {<Type.DATETIME: 'DATETIME'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.TIMESTAMP: 'TIMESTAMP'>}}
def
annotate(self, expression: ~E) -> ~E:
def annotate(self, expression: E) -> E:
    """
    Annotates the column references and expressions of every scope found in
    `expression`, and finally the expression itself.

    For each scope produced by `traverse_scope`:
        1. A lookup is built for every source that is itself a `Scope`, mapping
           the source's name to a `{column name: expression}` dictionary.
        2. Each table-qualified column in the scope receives a type, taken either
           from the schema (when its source is an `exp.Table`) or from the
           matching entry of the lookup built in step 1.
        3. The scope's expression is handed to `_maybe_annotate`.

    Args:
        expression: Expression to annotate.

    Returns:
        Whatever `_maybe_annotate` returns for `expression`.
    """
    for scope in traverse_scope(expression):
        # Source name -> {column name -> expression that produces that column}
        projections = {}

        for source_name, scope_source in scope.sources.items():
            if not isinstance(scope_source, Scope):
                continue

            source_expr = scope_source.expression

            if not isinstance(source_expr, exp.UDTF):
                # Columns are resolved by the alias (or name) of each select
                projections[source_name] = {
                    projection.alias_or_name: projection for projection in source_expr.selects
                }
                continue

            if not isinstance(source_expr, exp.Lateral):
                values = source_expr.expressions[0].expressions
            elif isinstance(source_expr.this, exp.Explode):
                values = [source_expr.this.this]
            else:
                values = []

            # Nothing to pair with the column aliases, so this source is skipped
            if not values:
                continue

            projections[source_name] = dict(zip(source_expr.alias_column_names, values))

        # First annotate the current scope's column references
        for column in scope.columns:
            table_name = column.table
            if not table_name:
                continue

            origin = scope.sources.get(table_name)
            if isinstance(origin, exp.Table):
                column.type = self.schema.get_column_type(origin, column)
            elif origin and table_name in projections and column.name in projections[table_name]:
                column.type = projections[table_name][column.name].type

        # Then (possibly) annotate the remaining expressions in the scope
        self._maybe_annotate(scope.expression)

    # This takes care of non-traversable expressions
    return self._maybe_annotate(expression)