sqlglot.optimizer.annotate_types
from __future__ import annotations

import typing as t

from sqlglot import exp
from sqlglot._typing import E
from sqlglot.helper import ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema, ensure_schema

if t.TYPE_CHECKING:
    B = t.TypeVar("B", bound=exp.Binary)


def annotate_types(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> E:
    """
    Infers the types of an expression, annotating its AST accordingly.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression: Expression to annotate.
        schema: Database schema.
        annotators: Maps expression type to corresponding annotation function.
        coerces_to: Maps each data type to the set of data types it can be coerced into.

    Returns:
        The expression annotated with types.
    """

    schema = ensure_schema(schema)

    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)


def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
    """Build an annotator that always assigns `data_type` to the expression it receives."""
    # A factory is used so that each returned lambda closes over its own `data_type`.
    return lambda self, e: self._annotate_with_type(e, data_type)


class _TypeAnnotator(type):
    """Metaclass that fills in the `COERCES_TO` table of the class being created."""

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)

        # Highest-to-lowest type precedence, as specified in Spark's docs (ANSI):
        # https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
        text_precedence = (
            exp.DataType.Type.TEXT,
            exp.DataType.Type.NVARCHAR,
            exp.DataType.Type.VARCHAR,
            exp.DataType.Type.NCHAR,
            exp.DataType.Type.CHAR,
        )
        numeric_precedence = (
            exp.DataType.Type.DOUBLE,
            exp.DataType.Type.FLOAT,
            exp.DataType.Type.DECIMAL,
            exp.DataType.Type.BIGINT,
            exp.DataType.Type.INT,
            exp.DataType.Type.SMALLINT,
            exp.DataType.Type.TINYINT,
        )
        timelike_precedence = (
            exp.DataType.Type.TIMESTAMPLTZ,
            exp.DataType.Type.TIMESTAMPTZ,
            exp.DataType.Type.TIMESTAMP,
            exp.DataType.Type.DATETIME,
            exp.DataType.Type.DATE,
        )

        # Each type maps to a snapshot of the types that precede it in its tuple,
        # i.e. the types with a higher precedence within the same family.
        for type_precedence in (text_precedence, numeric_precedence, timelike_precedence):
            coerces_to = set()
            for data_type in type_precedence:
                klass.COERCES_TO[data_type] = coerces_to.copy()
                coerces_to |= {data_type}

        return klass


class TypeAnnotator(metaclass=_TypeAnnotator):
    """
    Walks an expression tree and stores an inferred data type on its nodes.

    The behaviour is table driven: `ANNOTATORS` maps an expression class to the callable
    that annotates it, and `COERCES_TO` describes which types a given type can be coerced
    into. Both can be replaced per instance through the constructor.
    """

    # Expression classes that are always annotated with a fixed data type.
    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Count,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateAdd,
            exp.DateStrToDate,
            exp.DateSub,
            exp.DateTrunc,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Log2,
            exp.Log10,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DateDiff,
            exp.DatetimeDiff,
            exp.Extract,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Floor,
            exp.Levenshtein,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
        },
        exp.DataType.Type.VARCHAR: {
            exp.ArrayConcat,
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.SafeConcat,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    # Maps an expression class to a callable `(annotator, expression) -> expression`.
    # Later entries override earlier ones for the same key, so a class listed in
    # TYPE_TO_EXPRESSIONS (e.g. exp.Pow) takes its fixed type rather than the generic
    # unary / binary handling.
    ANNOTATORS: t.Dict = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
    }

    # Specifies what types a given type can be coerced into (autofilled)
    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}

    def __init__(
        self,
        schema: Schema,
        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
    ) -> None:
        """
        Args:
            schema: Schema used to look up the types of table columns.
            annotators: Replacement for `ANNOTATORS`; the class default is used when falsy.
            coerces_to: Replacement for `COERCES_TO`; the class default is used when falsy.
        """
        self.schema = schema
        self.annotators = annotators or self.ANNOTATORS
        self.coerces_to = coerces_to or self.COERCES_TO

    def annotate(self, expression: E) -> E:
        """
        Annotate `expression` scope by scope and return it.

        For every scope, table-qualified column references are typed first (from the
        schema for tables, from the matching projection for nested scopes), then the
        scope's own expression is annotated.
        """
        for scope in traverse_scope(expression):
            # Maps a source name to {output column name: expression producing it}
            selects = {}
            for name, source in scope.sources.items():
                if not isinstance(source, Scope):
                    continue
                if isinstance(source.expression, exp.UDTF):
                    values = []

                    if isinstance(source.expression, exp.Lateral):
                        if isinstance(source.expression.this, exp.Explode):
                            values = [source.expression.this.this]
                    else:
                        values = source.expression.expressions[0].expressions

                    if not values:
                        continue

                    selects[name] = {
                        alias: column
                        for alias, column in zip(
                            source.expression.alias_column_names,
                            values,
                        )
                    }
                else:
                    selects[name] = {
                        select.alias_or_name: select for select in source.expression.selects
                    }

            # First annotate the current scope's column references
            for col in scope.columns:
                if not col.table:
                    continue

                source = scope.sources.get(col.table)
                if isinstance(source, exp.Table):
                    col.type = self.schema.get_column_type(source, col)
                elif source and col.table in selects and col.name in selects[col.table]:
                    col.type = selects[col.table][col.name].type

            # Then (possibly) annotate the remaining expressions in the scope
            self._maybe_annotate(scope.expression)

        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions

    def _maybe_annotate(self, expression: E) -> E:
        """Annotate `expression` unless it already has a type; unmapped classes get UNKNOWN."""
        if expression.type:
            return expression  # We've already inferred the expression's type

        annotator = self.annotators.get(expression.__class__)

        return (
            annotator(self, expression)
            if annotator
            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
        )

    def _annotate_args(self, expression: E) -> E:
        """Annotate every child expression of `expression` and return `expression`."""
        for _, value in expression.iter_expressions():
            self._maybe_annotate(value)

        return expression

    def _maybe_coerce(
        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
    ) -> exp.DataType.Type:
        """
        Resolve the type produced by combining `type1` and `type2`.

        Returns NULL if either side is NULL, otherwise UNKNOWN if either side is UNKNOWN,
        otherwise `type2` when `type1` can be coerced into it and `type1` in any other case.
        """
        # We propagate the NULL / UNKNOWN types upwards if found
        if isinstance(type1, exp.DataType):
            type1 = type1.this
        if isinstance(type2, exp.DataType):
            type2 = type2.this

        if exp.DataType.Type.NULL in (type1, type2):
            return exp.DataType.Type.NULL
        if exp.DataType.Type.UNKNOWN in (type1, type2):
            return exp.DataType.Type.UNKNOWN

        return type2 if type2 in self.coerces_to.get(type1, {}) else type1  # type: ignore

    # Note: the following "no_type_check" decorators were added because mypy was yelling due
    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
    # This is a known mypy issue: https://github.com/python/mypy/issues/3004

    @t.no_type_check
    def _annotate_binary(self, expression: B) -> B:
        """
        Annotate a binary expression from the types of its two operands.

        Connectors are NULL when both sides are NULL, a NULLABLE BOOLEAN when exactly one
        side is NULL and BOOLEAN otherwise. Other predicates are BOOLEAN. Everything else
        takes the coerced type of its operands.
        """
        self._annotate_args(expression)

        left_type = expression.left.type.this
        right_type = expression.right.type.this

        if isinstance(expression, exp.Connector):
            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
                expression.type = exp.DataType.Type.NULL
            elif exp.DataType.Type.NULL in (left_type, right_type):
                expression.type = exp.DataType.build(
                    "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
                )
            else:
                expression.type = exp.DataType.Type.BOOLEAN
        elif isinstance(expression, exp.Predicate):
            expression.type = exp.DataType.Type.BOOLEAN
        else:
            expression.type = self._maybe_coerce(left_type, right_type)

        return expression

    @t.no_type_check
    def _annotate_unary(self, expression: E) -> E:
        """
        Annotate a unary-style expression (including aliases and parentheses).

        Logical negation is BOOLEAN; every other node handled here takes its operand's type.
        """
        self._annotate_args(expression)

        # Only NOT produces a boolean. Arithmetic negation and bitwise NOT are registered
        # for this annotator as well and must keep the type of their operand instead of
        # being reported as BOOLEAN.
        if isinstance(expression, exp.Not):
            expression.type = exp.DataType.Type.BOOLEAN
        else:
            expression.type = expression.this.type

        return expression

    @t.no_type_check
    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
        """Annotate a literal: strings are VARCHAR, integers are INT, anything else is DOUBLE."""
        if expression.is_string:
            expression.type = exp.DataType.Type.VARCHAR
        elif expression.is_int:
            expression.type = exp.DataType.Type.INT
        else:
            expression.type = exp.DataType.Type.DOUBLE

        return expression

    @t.no_type_check
    def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
        """Assign `target_type` to `expression`, then annotate its children."""
        expression.type = target_type
        return self._annotate_args(expression)

    @t.no_type_check
    def _annotate_by_args(self, expression: E, *args: str, promote: bool = False) -> E:
        """
        Annotate `expression` with the coerced type of the arguments named in `args`.

        Args:
            expression: Expression to annotate.
            *args: Names of the arguments (single expressions or lists) whose types are combined.
            promote: If set, integer results are widened to BIGINT and float results to DOUBLE.
        """
        self._annotate_args(expression)

        expressions: t.List[exp.Expression] = []
        for arg in args:
            arg_expr = expression.args.get(arg)
            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)

        last_datatype = None
        for expr in expressions:
            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)

        expression.type = last_datatype or exp.DataType.Type.UNKNOWN

        if promote:
            if expression.type.this in exp.DataType.INTEGER_TYPES:
                expression.type = exp.DataType.Type.BIGINT
            elif expression.type.this in exp.DataType.FLOAT_TYPES:
                expression.type = exp.DataType.Type.DOUBLE

        return expression
def
annotate_types( expression: ~E, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, annotators: Optional[Dict[Type[~E], Callable[[sqlglot.optimizer.annotate_types.TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None) -> ~E:
def annotate_types(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> E:
    """
    Annotate the AST of an expression with inferred types.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression: Expression to annotate.
        schema: Database schema; it is normalized with `ensure_schema` before use.
        annotators: Maps expression type to corresponding annotation function.
        coerces_to: Maps each data type to the set of data types it can be coerced into.

    Returns:
        The expression annotated with types.
    """
    annotator = TypeAnnotator(ensure_schema(schema), annotators, coerces_to)
    return annotator.annotate(expression)
Infers the types of an expression, annotating its AST accordingly.
Example:
>>> import sqlglot >>> schema = {"y": {"cola": "SMALLINT"}} >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x" >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema) >>> annotated_expr.expressions[0].type.this # Get the type of "x.cola + 2.5 AS cola" <Type.DOUBLE: 'DOUBLE'>
Arguments:
- expression: Expression to annotate.
- schema: Database schema.
- annotators: Maps expression type to corresponding annotation function.
- coerces_to: Maps each data type to the set of data types it can be coerced into.
Returns:
The expression annotated with types.
class
TypeAnnotator:
class TypeAnnotator(metaclass=_TypeAnnotator):
    """
    Walks an expression tree and stores an inferred data type on its nodes.

    The behaviour is table driven: `ANNOTATORS` maps an expression class to the callable
    that annotates it, and `COERCES_TO` describes which types a given type can be coerced
    into. Both can be replaced per instance through the constructor.
    """

    # Expression classes that are always annotated with a fixed data type.
    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Count,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateAdd,
            exp.DateStrToDate,
            exp.DateSub,
            exp.DateTrunc,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Log2,
            exp.Log10,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DateDiff,
            exp.DatetimeDiff,
            exp.Extract,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Floor,
            exp.Levenshtein,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
        },
        exp.DataType.Type.VARCHAR: {
            exp.ArrayConcat,
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.SafeConcat,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    # Maps an expression class to a callable `(annotator, expression) -> expression`.
    # Later entries override earlier ones for the same key, so a class listed in
    # TYPE_TO_EXPRESSIONS (e.g. exp.Pow) takes its fixed type rather than the generic
    # unary / binary handling.
    ANNOTATORS: t.Dict = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
    }

    # Specifies what types a given type can be coerced into (autofilled)
    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}

    def __init__(
        self,
        schema: Schema,
        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
    ) -> None:
        """
        Args:
            schema: Schema used to look up the types of table columns.
            annotators: Replacement for `ANNOTATORS`; the class default is used when falsy.
            coerces_to: Replacement for `COERCES_TO`; the class default is used when falsy.
        """
        self.schema = schema
        self.annotators = annotators or self.ANNOTATORS
        self.coerces_to = coerces_to or self.COERCES_TO

    def annotate(self, expression: E) -> E:
        """
        Annotate `expression` scope by scope and return it.

        For every scope, table-qualified column references are typed first (from the
        schema for tables, from the matching projection for nested scopes), then the
        scope's own expression is annotated.
        """
        for scope in traverse_scope(expression):
            # Maps a source name to {output column name: expression producing it}
            selects = {}
            for name, source in scope.sources.items():
                if not isinstance(source, Scope):
                    continue
                if isinstance(source.expression, exp.UDTF):
                    values = []

                    if isinstance(source.expression, exp.Lateral):
                        if isinstance(source.expression.this, exp.Explode):
                            values = [source.expression.this.this]
                    else:
                        values = source.expression.expressions[0].expressions

                    if not values:
                        continue

                    selects[name] = {
                        alias: column
                        for alias, column in zip(
                            source.expression.alias_column_names,
                            values,
                        )
                    }
                else:
                    selects[name] = {
                        select.alias_or_name: select for select in source.expression.selects
                    }

            # First annotate the current scope's column references
            for col in scope.columns:
                if not col.table:
                    continue

                source = scope.sources.get(col.table)
                if isinstance(source, exp.Table):
                    col.type = self.schema.get_column_type(source, col)
                elif source and col.table in selects and col.name in selects[col.table]:
                    col.type = selects[col.table][col.name].type

            # Then (possibly) annotate the remaining expressions in the scope
            self._maybe_annotate(scope.expression)

        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions

    def _maybe_annotate(self, expression: E) -> E:
        """Annotate `expression` unless it already has a type; unmapped classes get UNKNOWN."""
        if expression.type:
            return expression  # We've already inferred the expression's type

        annotator = self.annotators.get(expression.__class__)

        return (
            annotator(self, expression)
            if annotator
            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
        )

    def _annotate_args(self, expression: E) -> E:
        """Annotate every child expression of `expression` and return `expression`."""
        for _, value in expression.iter_expressions():
            self._maybe_annotate(value)

        return expression

    def _maybe_coerce(
        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
    ) -> exp.DataType.Type:
        """
        Resolve the type produced by combining `type1` and `type2`.

        Returns NULL if either side is NULL, otherwise UNKNOWN if either side is UNKNOWN,
        otherwise `type2` when `type1` can be coerced into it and `type1` in any other case.
        """
        # We propagate the NULL / UNKNOWN types upwards if found
        if isinstance(type1, exp.DataType):
            type1 = type1.this
        if isinstance(type2, exp.DataType):
            type2 = type2.this

        if exp.DataType.Type.NULL in (type1, type2):
            return exp.DataType.Type.NULL
        if exp.DataType.Type.UNKNOWN in (type1, type2):
            return exp.DataType.Type.UNKNOWN

        return type2 if type2 in self.coerces_to.get(type1, {}) else type1  # type: ignore

    # Note: the following "no_type_check" decorators were added because mypy was yelling due
    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
    # This is a known mypy issue: https://github.com/python/mypy/issues/3004

    @t.no_type_check
    def _annotate_binary(self, expression: B) -> B:
        """
        Annotate a binary expression from the types of its two operands.

        Connectors are NULL when both sides are NULL, a NULLABLE BOOLEAN when exactly one
        side is NULL and BOOLEAN otherwise. Other predicates are BOOLEAN. Everything else
        takes the coerced type of its operands.
        """
        self._annotate_args(expression)

        left_type = expression.left.type.this
        right_type = expression.right.type.this

        if isinstance(expression, exp.Connector):
            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
                expression.type = exp.DataType.Type.NULL
            elif exp.DataType.Type.NULL in (left_type, right_type):
                expression.type = exp.DataType.build(
                    "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
                )
            else:
                expression.type = exp.DataType.Type.BOOLEAN
        elif isinstance(expression, exp.Predicate):
            expression.type = exp.DataType.Type.BOOLEAN
        else:
            expression.type = self._maybe_coerce(left_type, right_type)

        return expression

    @t.no_type_check
    def _annotate_unary(self, expression: E) -> E:
        """
        Annotate a unary-style expression (including aliases and parentheses).

        Logical negation is BOOLEAN; every other node handled here takes its operand's type.
        """
        self._annotate_args(expression)

        # Only NOT produces a boolean. Arithmetic negation and bitwise NOT are registered
        # for this annotator as well and must keep the type of their operand instead of
        # being reported as BOOLEAN.
        if isinstance(expression, exp.Not):
            expression.type = exp.DataType.Type.BOOLEAN
        else:
            expression.type = expression.this.type

        return expression

    @t.no_type_check
    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
        """Annotate a literal: strings are VARCHAR, integers are INT, anything else is DOUBLE."""
        if expression.is_string:
            expression.type = exp.DataType.Type.VARCHAR
        elif expression.is_int:
            expression.type = exp.DataType.Type.INT
        else:
            expression.type = exp.DataType.Type.DOUBLE

        return expression

    @t.no_type_check
    def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
        """Assign `target_type` to `expression`, then annotate its children."""
        expression.type = target_type
        return self._annotate_args(expression)

    @t.no_type_check
    def _annotate_by_args(self, expression: E, *args: str, promote: bool = False) -> E:
        """
        Annotate `expression` with the coerced type of the arguments named in `args`.

        Args:
            expression: Expression to annotate.
            *args: Names of the arguments (single expressions or lists) whose types are combined.
            promote: If set, integer results are widened to BIGINT and float results to DOUBLE.
        """
        self._annotate_args(expression)

        expressions: t.List[exp.Expression] = []
        for arg in args:
            arg_expr = expression.args.get(arg)
            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)

        last_datatype = None
        for expr in expressions:
            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)

        expression.type = last_datatype or exp.DataType.Type.UNKNOWN

        if promote:
            if expression.type.this in exp.DataType.INTEGER_TYPES:
                expression.type = exp.DataType.Type.BIGINT
            elif expression.type.this in exp.DataType.FLOAT_TYPES:
                expression.type = exp.DataType.Type.DOUBLE

        return expression
TypeAnnotator( schema: sqlglot.schema.Schema, annotators: Optional[Dict[Type[~E], Callable[[sqlglot.optimizer.annotate_types.TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None)
def __init__(
    self,
    schema: Schema,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> None:
    """
    Store the schema and the lookup tables used while annotating.

    Args:
        schema: Schema kept on the instance as `self.schema`.
        annotators: Overrides the class-level `ANNOTATORS` table when truthy.
        coerces_to: Overrides the class-level `COERCES_TO` table when truthy.
    """
    # A falsy argument (None or an empty mapping) selects the class-level default.
    self.annotators = annotators if annotators else self.ANNOTATORS
    self.coerces_to = coerces_to if coerces_to else self.COERCES_TO
    self.schema = schema
TYPE_TO_EXPRESSIONS: Dict[sqlglot.expressions.DataType.Type, Set[Type[sqlglot.expressions.Expression]]] =
{<Type.BIGINT: 'BIGINT'>: {<class 'sqlglot.expressions.ArraySize'>, <class 'sqlglot.expressions.Length'>, <class 'sqlglot.expressions.ApproxDistinct'>, <class 'sqlglot.expressions.Count'>}, <Type.BOOLEAN: 'BOOLEAN'>: {<class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.Boolean'>, <class 'sqlglot.expressions.Between'>, <class 'sqlglot.expressions.RegexpLike'>}, <Type.DATE: 'DATE'>: {<class 'sqlglot.expressions.DateSub'>, <class 'sqlglot.expressions.DiToDate'>, <class 'sqlglot.expressions.DateAdd'>, <class 'sqlglot.expressions.TsOrDsToDate'>, <class 'sqlglot.expressions.Date'>, <class 'sqlglot.expressions.StrToDate'>, <class 'sqlglot.expressions.CurrentDate'>, <class 'sqlglot.expressions.DateStrToDate'>, <class 'sqlglot.expressions.DateTrunc'>, <class 'sqlglot.expressions.TimeStrToDate'>}, <Type.DATETIME: 'DATETIME'>: {<class 'sqlglot.expressions.DatetimeSub'>, <class 'sqlglot.expressions.DatetimeAdd'>, <class 'sqlglot.expressions.CurrentDatetime'>}, <Type.DOUBLE: 'DOUBLE'>: {<class 'sqlglot.expressions.Quantile'>, <class 'sqlglot.expressions.Ln'>, <class 'sqlglot.expressions.Sqrt'>, <class 'sqlglot.expressions.Exp'>, <class 'sqlglot.expressions.Variance'>, <class 'sqlglot.expressions.VariancePop'>, <class 'sqlglot.expressions.Avg'>, <class 'sqlglot.expressions.Pow'>, <class 'sqlglot.expressions.StddevSamp'>, <class 'sqlglot.expressions.SafeDivide'>, <class 'sqlglot.expressions.StddevPop'>, <class 'sqlglot.expressions.Log10'>, <class 'sqlglot.expressions.Stddev'>, <class 'sqlglot.expressions.Round'>, <class 'sqlglot.expressions.Log2'>, <class 'sqlglot.expressions.ApproxQuantile'>, <class 'sqlglot.expressions.Log'>}, <Type.INT: 'INT'>: {<class 'sqlglot.expressions.TsOrDiToDi'>, <class 'sqlglot.expressions.TimeDiff'>, <class 'sqlglot.expressions.TimestampDiff'>, <class 'sqlglot.expressions.Levenshtein'>, <class 'sqlglot.expressions.StrPosition'>, <class 'sqlglot.expressions.DatetimeDiff'>, <class 'sqlglot.expressions.Ceil'>, <class 
'sqlglot.expressions.Extract'>, <class 'sqlglot.expressions.DateToDi'>, <class 'sqlglot.expressions.Floor'>, <class 'sqlglot.expressions.DateDiff'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<class 'sqlglot.expressions.TimeSub'>, <class 'sqlglot.expressions.TimestampSub'>, <class 'sqlglot.expressions.TimeAdd'>, <class 'sqlglot.expressions.TimestampAdd'>, <class 'sqlglot.expressions.CurrentTimestamp'>, <class 'sqlglot.expressions.CurrentTime'>, <class 'sqlglot.expressions.StrToTime'>, <class 'sqlglot.expressions.UnixToTime'>, <class 'sqlglot.expressions.TimeStrToTime'>}, <Type.TINYINT: 'TINYINT'>: {<class 'sqlglot.expressions.Day'>, <class 'sqlglot.expressions.Month'>, <class 'sqlglot.expressions.Year'>, <class 'sqlglot.expressions.Week'>}, <Type.VARCHAR: 'VARCHAR'>: {<class 'sqlglot.expressions.ConcatWs'>, <class 'sqlglot.expressions.Initcap'>, <class 'sqlglot.expressions.SafeConcat'>, <class 'sqlglot.expressions.Substring'>, <class 'sqlglot.expressions.TimeToTimeStr'>, <class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.TimeToStr'>, <class 'sqlglot.expressions.TsOrDsToDateStr'>, <class 'sqlglot.expressions.GroupConcat'>, <class 'sqlglot.expressions.Upper'>, <class 'sqlglot.expressions.Lower'>, <class 'sqlglot.expressions.ArrayConcat'>, <class 'sqlglot.expressions.Trim'>, <class 'sqlglot.expressions.UnixToTimeStr'>, <class 'sqlglot.expressions.DateToDateStr'>, <class 'sqlglot.expressions.UnixToStr'>}}
ANNOTATORS: Dict =
{<class 'sqlglot.expressions.Alias'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseNot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Neg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Not'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Paren'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Unary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Add'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.And'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContained'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayOverlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Binary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseAnd'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseLeftShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseOr'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseRightShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseXor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Collate'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Connector'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.DPipe'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Distance'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Div'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Dot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 
'sqlglot.expressions.EQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Escape'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Glob'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILike'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.IntDiv'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Is'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Kwarg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Like'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mod'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mul'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeNEQ'>: <function 
TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Or'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Overlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Pow'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDPipe'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.SimilarTo'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Slice'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Sub'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArraySize'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Length'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxDistinct'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Count'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.In'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Boolean'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Between'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.RegexpLike'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DiToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Date'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 
'sqlglot.expressions.CurrentDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateTrunc'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDatetime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Quantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ln'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Sqrt'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Exp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Variance'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.VariancePop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Avg'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevSamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDivide'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevPop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log10'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Stddev'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Round'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log2'>: <function 
_annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxQuantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDiToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Levenshtein'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrPosition'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ceil'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Extract'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Floor'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTimestamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToTime'>: <function 
_annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Day'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Month'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Year'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Week'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ConcatWs'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Initcap'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Substring'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Concat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.GroupConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Upper'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Lower'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArrayConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Trim'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 
'sqlglot.expressions.DateToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Anonymous'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Cast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Case'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Coalesce'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DataType'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.If'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Interval'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Least'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Literal'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Map'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Max'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Min'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Null'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Sum'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.TryCast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.VarMap'>: <function TypeAnnotator.<lambda>>}
COERCES_TO: Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]] =
{<Type.TEXT: 'TEXT'>: set(), <Type.NVARCHAR: 'NVARCHAR'>: {<Type.TEXT: 'TEXT'>}, <Type.VARCHAR: 'VARCHAR'>: {<Type.TEXT: 'TEXT'>, <Type.NVARCHAR: 'NVARCHAR'>}, <Type.NCHAR: 'NCHAR'>: {<Type.TEXT: 'TEXT'>, <Type.NVARCHAR: 'NVARCHAR'>, <Type.VARCHAR: 'VARCHAR'>}, <Type.CHAR: 'CHAR'>: {<Type.NCHAR: 'NCHAR'>, <Type.TEXT: 'TEXT'>, <Type.NVARCHAR: 'NVARCHAR'>, <Type.VARCHAR: 'VARCHAR'>}, <Type.DOUBLE: 'DOUBLE'>: set(), <Type.FLOAT: 'FLOAT'>: {<Type.DOUBLE: 'DOUBLE'>}, <Type.DECIMAL: 'DECIMAL'>: {<Type.FLOAT: 'FLOAT'>, <Type.DOUBLE: 'DOUBLE'>}, <Type.BIGINT: 'BIGINT'>: {<Type.DECIMAL: 'DECIMAL'>, <Type.FLOAT: 'FLOAT'>, <Type.DOUBLE: 'DOUBLE'>}, <Type.INT: 'INT'>: {<Type.DECIMAL: 'DECIMAL'>, <Type.BIGINT: 'BIGINT'>, <Type.FLOAT: 'FLOAT'>, <Type.DOUBLE: 'DOUBLE'>}, <Type.SMALLINT: 'SMALLINT'>: {<Type.BIGINT: 'BIGINT'>, <Type.DOUBLE: 'DOUBLE'>, <Type.INT: 'INT'>, <Type.FLOAT: 'FLOAT'>, <Type.DECIMAL: 'DECIMAL'>}, <Type.TINYINT: 'TINYINT'>: {<Type.SMALLINT: 'SMALLINT'>, <Type.BIGINT: 'BIGINT'>, <Type.DOUBLE: 'DOUBLE'>, <Type.INT: 'INT'>, <Type.FLOAT: 'FLOAT'>, <Type.DECIMAL: 'DECIMAL'>}, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>: set(), <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>: {<Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>}, <Type.DATETIME: 'DATETIME'>: {<Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMP: 'TIMESTAMP'>}, <Type.DATE: 'DATE'>: {<Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.DATETIME: 'DATETIME'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMP: 'TIMESTAMP'>}}
def annotate(self, expression: ~E) -> ~E:
def annotate(self, expression: E) -> E:
    """
    Annotate `expression` with types, working through it one scope at a time.

    For each scope yielded by `traverse_scope`, this:
      1. collects, per `Scope` source, a mapping from output column name to the
         expression that produces it,
      2. assigns a type to every table-qualified column in the scope, taken either
         from the schema (when the source is an `exp.Table`) or from the mapping
         built in step 1,
      3. passes the scope's own expression to `_maybe_annotate`.

    Args:
        expression: The expression to annotate.

    Returns:
        Whatever `_maybe_annotate` returns for `expression`.
    """
    for scope in traverse_scope(expression):
        # Source name -> {output column name -> expression that produces it}
        projections = {}

        for source_name, child in scope.sources.items():
            if not isinstance(child, Scope):
                continue

            child_expr = child.expression

            if not isinstance(child_expr, exp.UDTF):
                projections[source_name] = {
                    projection.alias_or_name: projection for projection in child_expr.selects
                }
                continue

            # UDTF source: pick the value expressions to pair with its column aliases
            if not isinstance(child_expr, exp.Lateral):
                values = child_expr.expressions[0].expressions
            elif isinstance(child_expr.this, exp.Explode):
                values = [child_expr.this.this]
            else:
                values = []

            if values:
                projections[source_name] = dict(zip(child_expr.alias_column_names, values))

        # First annotate the current scope's column references
        for column in scope.columns:
            table_name = column.table
            if not table_name:
                # Columns without a table qualifier are left untouched here
                continue

            origin = scope.sources.get(table_name)
            if isinstance(origin, exp.Table):
                column.type = self.schema.get_column_type(origin, column)
                continue

            candidates = projections.get(table_name) if origin else None
            if candidates is not None and column.name in candidates:
                column.type = candidates[column.name].type

        # Then (possibly) annotate the remaining expressions in the scope
        self._maybe_annotate(scope.expression)

    return self._maybe_annotate(expression)  # This takes care of non-traversable expressions