sqlglot.optimizer.annotate_types
1from __future__ import annotations 2 3import typing as t 4 5from sqlglot import exp 6from sqlglot._typing import E 7from sqlglot.helper import ensure_list, subclasses 8from sqlglot.optimizer.scope import Scope, traverse_scope 9from sqlglot.schema import Schema, ensure_schema 10 11if t.TYPE_CHECKING: 12 B = t.TypeVar("B", bound=exp.Binary) 13 14 15def annotate_types( 16 expression: E, 17 schema: t.Optional[t.Dict | Schema] = None, 18 annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None, 19 coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None, 20) -> E: 21 """ 22 Infers the types of an expression, annotating its AST accordingly. 23 24 Example: 25 >>> import sqlglot 26 >>> schema = {"y": {"cola": "SMALLINT"}} 27 >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x" 28 >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema) 29 >>> annotated_expr.expressions[0].type.this # Get the type of "x.cola + 2.5 AS cola" 30 <Type.DOUBLE: 'DOUBLE'> 31 32 Args: 33 expression: Expression to annotate. 34 schema: Database schema. 35 annotators: Maps expression type to corresponding annotation function. 36 coerces_to: Maps expression type to set of types that it can be coerced into. 37 38 Returns: 39 The expression annotated with types. 
40 """ 41 42 schema = ensure_schema(schema) 43 44 return TypeAnnotator(schema, annotators, coerces_to).annotate(expression) 45 46 47def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]: 48 return lambda self, e: self._annotate_with_type(e, data_type) 49 50 51class _TypeAnnotator(type): 52 def __new__(cls, clsname, bases, attrs): 53 klass = super().__new__(cls, clsname, bases, attrs) 54 55 # Highest-to-lowest type precedence, as specified in Spark's docs (ANSI): 56 # https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html 57 text_precedence = ( 58 exp.DataType.Type.TEXT, 59 exp.DataType.Type.NVARCHAR, 60 exp.DataType.Type.VARCHAR, 61 exp.DataType.Type.NCHAR, 62 exp.DataType.Type.CHAR, 63 ) 64 numeric_precedence = ( 65 exp.DataType.Type.DOUBLE, 66 exp.DataType.Type.FLOAT, 67 exp.DataType.Type.DECIMAL, 68 exp.DataType.Type.BIGINT, 69 exp.DataType.Type.INT, 70 exp.DataType.Type.SMALLINT, 71 exp.DataType.Type.TINYINT, 72 ) 73 timelike_precedence = ( 74 exp.DataType.Type.TIMESTAMPLTZ, 75 exp.DataType.Type.TIMESTAMPTZ, 76 exp.DataType.Type.TIMESTAMP, 77 exp.DataType.Type.DATETIME, 78 exp.DataType.Type.DATE, 79 ) 80 81 for type_precedence in (text_precedence, numeric_precedence, timelike_precedence): 82 coerces_to = set() 83 for data_type in type_precedence: 84 klass.COERCES_TO[data_type] = coerces_to.copy() 85 coerces_to |= {data_type} 86 87 return klass 88 89 90class TypeAnnotator(metaclass=_TypeAnnotator): 91 TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 92 exp.DataType.Type.BIGINT: { 93 exp.ApproxDistinct, 94 exp.ArraySize, 95 exp.Count, 96 exp.Length, 97 }, 98 exp.DataType.Type.BOOLEAN: { 99 exp.Between, 100 exp.Boolean, 101 exp.In, 102 exp.RegexpLike, 103 }, 104 exp.DataType.Type.DATE: { 105 exp.CurrentDate, 106 exp.Date, 107 exp.DateAdd, 108 exp.DateFromParts, 109 exp.DateStrToDate, 110 exp.DateSub, 111 exp.DateTrunc, 112 exp.DiToDate, 113 exp.StrToDate, 114 exp.TimeStrToDate, 
115 exp.TsOrDsToDate, 116 }, 117 exp.DataType.Type.DATETIME: { 118 exp.CurrentDatetime, 119 exp.DatetimeAdd, 120 exp.DatetimeSub, 121 }, 122 exp.DataType.Type.DOUBLE: { 123 exp.ApproxQuantile, 124 exp.Avg, 125 exp.Exp, 126 exp.Ln, 127 exp.Log, 128 exp.Log2, 129 exp.Log10, 130 exp.Pow, 131 exp.Quantile, 132 exp.Round, 133 exp.SafeDivide, 134 exp.Sqrt, 135 exp.Stddev, 136 exp.StddevPop, 137 exp.StddevSamp, 138 exp.Variance, 139 exp.VariancePop, 140 }, 141 exp.DataType.Type.INT: { 142 exp.Ceil, 143 exp.DateDiff, 144 exp.DatetimeDiff, 145 exp.Extract, 146 exp.TimestampDiff, 147 exp.TimeDiff, 148 exp.DateToDi, 149 exp.Floor, 150 exp.Levenshtein, 151 exp.StrPosition, 152 exp.TsOrDiToDi, 153 }, 154 exp.DataType.Type.TIMESTAMP: { 155 exp.CurrentTime, 156 exp.CurrentTimestamp, 157 exp.StrToTime, 158 exp.TimeAdd, 159 exp.TimeStrToTime, 160 exp.TimeSub, 161 exp.TimestampAdd, 162 exp.TimestampSub, 163 exp.UnixToTime, 164 }, 165 exp.DataType.Type.TINYINT: { 166 exp.Day, 167 exp.Month, 168 exp.Week, 169 exp.Year, 170 }, 171 exp.DataType.Type.VARCHAR: { 172 exp.ArrayConcat, 173 exp.Concat, 174 exp.ConcatWs, 175 exp.DateToDateStr, 176 exp.GroupConcat, 177 exp.Initcap, 178 exp.Lower, 179 exp.SafeConcat, 180 exp.Substring, 181 exp.TimeToStr, 182 exp.TimeToTimeStr, 183 exp.Trim, 184 exp.TsOrDsToDateStr, 185 exp.UnixToStr, 186 exp.UnixToTimeStr, 187 exp.Upper, 188 }, 189 } 190 191 ANNOTATORS: t.Dict = { 192 **{ 193 expr_type: lambda self, e: self._annotate_unary(e) 194 for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) 195 }, 196 **{ 197 expr_type: lambda self, e: self._annotate_binary(e) 198 for expr_type in subclasses(exp.__name__, exp.Binary) 199 }, 200 **{ 201 expr_type: _annotate_with_type_lambda(data_type) 202 for data_type, expressions in TYPE_TO_EXPRESSIONS.items() 203 for expr_type in expressions 204 }, 205 exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 206 exp.Cast: lambda self, e: self._annotate_with_type(e, 
e.args["to"]), 207 exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), 208 exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 209 exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), 210 exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), 211 exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), 212 exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"), 213 exp.Literal: lambda self, e: self._annotate_literal(e), 214 exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP), 215 exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 216 exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 217 exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), 218 exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), 219 exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 220 exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP), 221 } 222 223 # Specifies what types a given type can be coerced into (autofilled) 224 COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {} 225 226 def __init__( 227 self, 228 schema: Schema, 229 annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None, 230 coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None, 231 ) -> None: 232 self.schema = schema 233 self.annotators = annotators or self.ANNOTATORS 234 self.coerces_to = coerces_to or self.COERCES_TO 235 236 def annotate(self, expression: E) -> E: 237 for scope in traverse_scope(expression): 238 selects = {} 239 for name, source in scope.sources.items(): 240 if not isinstance(source, Scope): 241 continue 242 if isinstance(source.expression, exp.UDTF): 243 values = [] 244 245 if isinstance(source.expression, exp.Lateral): 246 if 
isinstance(source.expression.this, exp.Explode): 247 values = [source.expression.this.this] 248 else: 249 values = source.expression.expressions[0].expressions 250 251 if not values: 252 continue 253 254 selects[name] = { 255 alias: column 256 for alias, column in zip( 257 source.expression.alias_column_names, 258 values, 259 ) 260 } 261 else: 262 selects[name] = { 263 select.alias_or_name: select for select in source.expression.selects 264 } 265 266 # First annotate the current scope's column references 267 for col in scope.columns: 268 if not col.table: 269 continue 270 271 source = scope.sources.get(col.table) 272 if isinstance(source, exp.Table): 273 col.type = self.schema.get_column_type(source, col) 274 elif source and col.table in selects and col.name in selects[col.table]: 275 col.type = selects[col.table][col.name].type 276 277 # Then (possibly) annotate the remaining expressions in the scope 278 self._maybe_annotate(scope.expression) 279 280 return self._maybe_annotate(expression) # This takes care of non-traversable expressions 281 282 def _maybe_annotate(self, expression: E) -> E: 283 if expression.type: 284 return expression # We've already inferred the expression's type 285 286 annotator = self.annotators.get(expression.__class__) 287 288 return ( 289 annotator(self, expression) 290 if annotator 291 else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN) 292 ) 293 294 def _annotate_args(self, expression: E) -> E: 295 for _, value in expression.iter_expressions(): 296 self._maybe_annotate(value) 297 298 return expression 299 300 def _maybe_coerce( 301 self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type 302 ) -> exp.DataType.Type: 303 # We propagate the NULL / UNKNOWN types upwards if found 304 if isinstance(type1, exp.DataType): 305 type1 = type1.this 306 if isinstance(type2, exp.DataType): 307 type2 = type2.this 308 309 if exp.DataType.Type.NULL in (type1, type2): 310 return exp.DataType.Type.NULL 311 if 
exp.DataType.Type.UNKNOWN in (type1, type2): 312 return exp.DataType.Type.UNKNOWN 313 314 return type2 if type2 in self.coerces_to.get(type1, {}) else type1 # type: ignore 315 316 # Note: the following "no_type_check" decorators were added because mypy was yelling due 317 # to assigning Type values to expression.type (since its getter returns Optional[DataType]). 318 # This is a known mypy issue: https://github.com/python/mypy/issues/3004 319 320 @t.no_type_check 321 def _annotate_binary(self, expression: B) -> B: 322 self._annotate_args(expression) 323 324 left_type = expression.left.type.this 325 right_type = expression.right.type.this 326 327 if isinstance(expression, exp.Connector): 328 if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL: 329 expression.type = exp.DataType.Type.NULL 330 elif exp.DataType.Type.NULL in (left_type, right_type): 331 expression.type = exp.DataType.build( 332 "NULLABLE", expressions=exp.DataType.build("BOOLEAN") 333 ) 334 else: 335 expression.type = exp.DataType.Type.BOOLEAN 336 elif isinstance(expression, exp.Predicate): 337 expression.type = exp.DataType.Type.BOOLEAN 338 else: 339 expression.type = self._maybe_coerce(left_type, right_type) 340 341 return expression 342 343 @t.no_type_check 344 def _annotate_unary(self, expression: E) -> E: 345 self._annotate_args(expression) 346 347 if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren): 348 expression.type = exp.DataType.Type.BOOLEAN 349 else: 350 expression.type = expression.this.type 351 352 return expression 353 354 @t.no_type_check 355 def _annotate_literal(self, expression: exp.Literal) -> exp.Literal: 356 if expression.is_string: 357 expression.type = exp.DataType.Type.VARCHAR 358 elif expression.is_int: 359 expression.type = exp.DataType.Type.INT 360 else: 361 expression.type = exp.DataType.Type.DOUBLE 362 363 return expression 364 365 @t.no_type_check 366 def _annotate_with_type(self, expression: E, target_type: 
exp.DataType.Type) -> E: 367 expression.type = target_type 368 return self._annotate_args(expression) 369 370 @t.no_type_check 371 def _annotate_by_args(self, expression: E, *args: str, promote: bool = False) -> E: 372 self._annotate_args(expression) 373 374 expressions: t.List[exp.Expression] = [] 375 for arg in args: 376 arg_expr = expression.args.get(arg) 377 expressions.extend(expr for expr in ensure_list(arg_expr) if expr) 378 379 last_datatype = None 380 for expr in expressions: 381 last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type) 382 383 expression.type = last_datatype or exp.DataType.Type.UNKNOWN 384 385 if promote: 386 if expression.type.this in exp.DataType.INTEGER_TYPES: 387 expression.type = exp.DataType.Type.BIGINT 388 elif expression.type.this in exp.DataType.FLOAT_TYPES: 389 expression.type = exp.DataType.Type.DOUBLE 390 391 return expression
def
annotate_types( expression: ~E, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, annotators: Optional[Dict[Type[~E], Callable[[sqlglot.optimizer.annotate_types.TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None) -> ~E:
def annotate_types(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> E:
    """
    Annotate the AST of an expression with inferred types.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression: Expression to annotate.
        schema: Database schema; it is passed through `ensure_schema` before use.
        annotators: Maps an expression class to the function that annotates it.
        coerces_to: Maps a data type to the set of data types it can be coerced into.

    Returns:
        Whatever `TypeAnnotator.annotate` returns for `expression`.
    """
    annotator = TypeAnnotator(ensure_schema(schema), annotators, coerces_to)
    return annotator.annotate(expression)
Infers the types of an expression, annotating its AST accordingly.
Example:
>>> import sqlglot >>> schema = {"y": {"cola": "SMALLINT"}} >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x" >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema) >>> annotated_expr.expressions[0].type.this # Get the type of "x.cola + 2.5 AS cola" <Type.DOUBLE: 'DOUBLE'>
Arguments:
- expression: Expression to annotate.
- schema: Database schema.
- annotators: Maps expression type to corresponding annotation function.
- coerces_to: Maps a data type to the set of data types that it can be coerced into.
Returns:
The expression annotated with types.
class
TypeAnnotator:
class TypeAnnotator(metaclass=_TypeAnnotator):
    """
    Assigns a type to the nodes of an expression tree.

    The class-level tables drive the inference: `ANNOTATORS` maps an expression class to the
    function that annotates it, and `COERCES_TO` lists, for a given type, the types it may be
    coerced into. Both can be replaced per instance through the constructor.
    """

    # Expression classes whose type does not depend on their arguments, grouped by that type
    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Count,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateAdd,
            exp.DateFromParts,
            exp.DateStrToDate,
            exp.DateSub,
            exp.DateTrunc,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Log2,
            exp.Log10,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DateDiff,
            exp.DatetimeDiff,
            exp.Extract,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Floor,
            exp.Levenshtein,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
        },
        exp.DataType.Type.VARCHAR: {
            exp.ArrayConcat,
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.SafeConcat,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    # Later entries win over earlier ones, so a class listed in TYPE_TO_EXPRESSIONS (or given
    # an explicit entry below) overrides the generic unary / binary annotator it would
    # otherwise get from the first two groups.
    ANNOTATORS: t.Dict = {
        **{
            node_type: lambda self, e: self._annotate_unary(e)
            for node_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            node_type: lambda self, e: self._annotate_binary(e)
            for node_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            node_type: _annotate_with_type_lambda(result_type)
            for result_type, node_types in TYPE_TO_EXPRESSIONS.items()
            for node_type in node_types
        },
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
    }

    # Specifies what types a given type can be coerced into (autofilled)
    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}

    def __init__(
        self,
        schema: Schema,
        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
    ) -> None:
        """
        Args:
            schema: Schema used to look up the types of table columns.
            annotators: Replacement for `ANNOTATORS`; a falsy value selects the class default.
            coerces_to: Replacement for `COERCES_TO`; a falsy value selects the class default.
        """
        self.schema = schema
        self.annotators = annotators if annotators else self.ANNOTATORS
        self.coerces_to = coerces_to if coerces_to else self.COERCES_TO

    def annotate(self, expression: E) -> E:
        """
        Annotate `expression` scope by scope and return it.

        Within each scope yielded by `traverse_scope`, table-qualified columns are typed
        first: from the schema when the source is a table, otherwise from the matching
        projection of the source scope. The scope's expression is annotated afterwards.
        """
        for scope in traverse_scope(expression):
            # Source name -> {output column name -> expression that produces the column}
            projections_by_source = {}

            for source_name, source in scope.sources.items():
                if not isinstance(source, Scope):
                    continue

                source_expr = source.expression

                if not isinstance(source_expr, exp.UDTF):
                    projections_by_source[source_name] = {
                        projection.alias_or_name: projection
                        for projection in source_expr.selects
                    }
                    continue

                if not isinstance(source_expr, exp.Lateral):
                    values = source_expr.expressions[0].expressions
                elif isinstance(source_expr.this, exp.Explode):
                    values = [source_expr.this.this]
                else:
                    values = []

                # A UDTF for which no values were found contributes no columns
                if values:
                    projections_by_source[source_name] = dict(
                        zip(source_expr.alias_column_names, values)
                    )

            # First annotate the current scope's column references
            for column in scope.columns:
                table_name = column.table
                if not table_name:
                    continue

                source = scope.sources.get(table_name)
                if isinstance(source, exp.Table):
                    column.type = self.schema.get_column_type(source, column)
                elif source:
                    projections = projections_by_source.get(table_name)
                    if projections is not None and column.name in projections:
                        column.type = projections[column.name].type

            # Then (possibly) annotate the remaining expressions in the scope
            self._maybe_annotate(scope.expression)

        # This takes care of non-traversable expressions
        return self._maybe_annotate(expression)

    def _maybe_annotate(self, expression: E) -> E:
        """Annotate `expression` unless it already carries a type, and return the result."""
        if not expression.type:
            annotator = self.annotators.get(expression.__class__)
            if annotator:
                return annotator(self, expression)

            # No annotator is registered for this class
            return self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)

        # We've already inferred the expression's type
        return expression

    def _annotate_args(self, expression: E) -> E:
        """Annotate every value yielded by `expression.iter_expressions()`."""
        for _key, child in expression.iter_expressions():
            self._maybe_annotate(child)

        return expression

    def _maybe_coerce(
        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
    ) -> exp.DataType.Type:
        """
        Pick the type that results from combining `type1` and `type2`.

        `DataType` arguments are first reduced to their `this` value. NULL wins over
        everything, then UNKNOWN. Otherwise `type2` is returned if `type1` is registered as
        coercible into it, and `type1` is returned in every other case.
        """
        left, right = (
            dtype.this if isinstance(dtype, exp.DataType) else dtype for dtype in (type1, type2)
        )

        # We propagate the NULL / UNKNOWN types upwards if found
        for propagated in (exp.DataType.Type.NULL, exp.DataType.Type.UNKNOWN):
            if propagated in (left, right):
                return propagated

        if right in self.coerces_to.get(left, {}):  # type: ignore
            return right
        return left

    # Note: the following "no_type_check" decorators were added because mypy was yelling due
    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
    # This is a known mypy issue: https://github.com/python/mypy/issues/3004

    @t.no_type_check
    def _annotate_binary(self, expression: B) -> B:
        """
        Annotate a binary expression after annotating its operands.

        Connectors are typed as NULL when both operands are NULL, as NULLABLE<BOOLEAN> when
        exactly one is, and as BOOLEAN otherwise. Other predicates are BOOLEAN, and anything
        else gets the coerced type of its two operands.
        """
        self._annotate_args(expression)

        left_type = expression.left.type.this
        right_type = expression.right.type.this
        null_type = exp.DataType.Type.NULL

        if isinstance(expression, exp.Connector):
            if left_type == null_type and right_type == null_type:
                inferred = null_type
            elif null_type in (left_type, right_type):
                inferred = exp.DataType.build(
                    "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
                )
            else:
                inferred = exp.DataType.Type.BOOLEAN
        elif isinstance(expression, exp.Predicate):
            inferred = exp.DataType.Type.BOOLEAN
        else:
            inferred = self._maybe_coerce(left_type, right_type)

        expression.type = inferred
        return expression

    @t.no_type_check
    def _annotate_unary(self, expression: E) -> E:
        """
        Annotate a unary-like expression after annotating its arguments.

        Conditions other than `Paren` are BOOLEAN; everything else takes the type of
        `expression.this`.
        """
        self._annotate_args(expression)

        if isinstance(expression, exp.Paren) or not isinstance(expression, exp.Condition):
            expression.type = expression.this.type
        else:
            expression.type = exp.DataType.Type.BOOLEAN

        return expression

    @t.no_type_check
    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
        """Type a literal as VARCHAR (string), INT (integer) or DOUBLE (anything else)."""
        if expression.is_string:
            literal_type = exp.DataType.Type.VARCHAR
        else:
            literal_type = (
                exp.DataType.Type.INT if expression.is_int else exp.DataType.Type.DOUBLE
            )

        expression.type = literal_type
        return expression

    @t.no_type_check
    def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
        """Assign `target_type` to `expression`, then annotate the expression's arguments."""
        expression.type = target_type
        return self._annotate_args(expression)

    @t.no_type_check
    def _annotate_by_args(self, expression: E, *args: str, promote: bool = False) -> E:
        """
        Derive the type of `expression` from the arguments named in `args`.

        Args:
            expression: Expression to annotate.
            *args: Keys of `expression.args` whose (truthy) values take part in the coercion.
            promote: If set, an integer result becomes BIGINT and a float result DOUBLE.

        Returns:
            The annotated expression; its type is UNKNOWN when no operand yields a type.
        """
        self._annotate_args(expression)

        operands: t.List[exp.Expression] = [
            operand
            for arg_name in args
            for operand in ensure_list(expression.args.get(arg_name))
            if operand
        ]

        coerced = None
        for operand in operands:
            coerced = self._maybe_coerce(coerced or operand.type, operand.type)

        expression.type = coerced or exp.DataType.Type.UNKNOWN

        if promote:
            resolved = expression.type.this
            if resolved in exp.DataType.INTEGER_TYPES:
                expression.type = exp.DataType.Type.BIGINT
            elif resolved in exp.DataType.FLOAT_TYPES:
                expression.type = exp.DataType.Type.DOUBLE

        return expression
TypeAnnotator( schema: sqlglot.schema.Schema, annotators: Optional[Dict[Type[~E], Callable[[sqlglot.optimizer.annotate_types.TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None)
def __init__(
    self,
    schema: Schema,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> None:
    """
    Args:
        schema: Schema stored on the instance as `self.schema`.
        annotators: Replacement for `ANNOTATORS`; a falsy value selects the class default.
        coerces_to: Replacement for `COERCES_TO`; a falsy value selects the class default.
    """
    self.schema = schema
    self.annotators = annotators if annotators else self.ANNOTATORS
    self.coerces_to = coerces_to if coerces_to else self.COERCES_TO
TYPE_TO_EXPRESSIONS: Dict[sqlglot.expressions.DataType.Type, Set[Type[sqlglot.expressions.Expression]]] =
{<Type.BIGINT: 'BIGINT'>: {<class 'sqlglot.expressions.Length'>, <class 'sqlglot.expressions.ApproxDistinct'>, <class 'sqlglot.expressions.ArraySize'>, <class 'sqlglot.expressions.Count'>}, <Type.BOOLEAN: 'BOOLEAN'>: {<class 'sqlglot.expressions.Between'>, <class 'sqlglot.expressions.Boolean'>, <class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.RegexpLike'>}, <Type.DATE: 'DATE'>: {<class 'sqlglot.expressions.DateSub'>, <class 'sqlglot.expressions.TsOrDsToDate'>, <class 'sqlglot.expressions.Date'>, <class 'sqlglot.expressions.DateAdd'>, <class 'sqlglot.expressions.DateTrunc'>, <class 'sqlglot.expressions.DiToDate'>, <class 'sqlglot.expressions.DateStrToDate'>, <class 'sqlglot.expressions.DateFromParts'>, <class 'sqlglot.expressions.StrToDate'>, <class 'sqlglot.expressions.CurrentDate'>, <class 'sqlglot.expressions.TimeStrToDate'>}, <Type.DATETIME: 'DATETIME'>: {<class 'sqlglot.expressions.DatetimeAdd'>, <class 'sqlglot.expressions.CurrentDatetime'>, <class 'sqlglot.expressions.DatetimeSub'>}, <Type.DOUBLE: 'DOUBLE'>: {<class 'sqlglot.expressions.Log2'>, <class 'sqlglot.expressions.Avg'>, <class 'sqlglot.expressions.Log10'>, <class 'sqlglot.expressions.Log'>, <class 'sqlglot.expressions.Round'>, <class 'sqlglot.expressions.Ln'>, <class 'sqlglot.expressions.Variance'>, <class 'sqlglot.expressions.Pow'>, <class 'sqlglot.expressions.VariancePop'>, <class 'sqlglot.expressions.StddevSamp'>, <class 'sqlglot.expressions.StddevPop'>, <class 'sqlglot.expressions.Stddev'>, <class 'sqlglot.expressions.Sqrt'>, <class 'sqlglot.expressions.Quantile'>, <class 'sqlglot.expressions.Exp'>, <class 'sqlglot.expressions.SafeDivide'>, <class 'sqlglot.expressions.ApproxQuantile'>}, <Type.INT: 'INT'>: {<class 'sqlglot.expressions.DateDiff'>, <class 'sqlglot.expressions.TsOrDiToDi'>, <class 'sqlglot.expressions.TimeDiff'>, <class 'sqlglot.expressions.Ceil'>, <class 'sqlglot.expressions.Levenshtein'>, <class 'sqlglot.expressions.DateToDi'>, <class 'sqlglot.expressions.Extract'>, 
<class 'sqlglot.expressions.DatetimeDiff'>, <class 'sqlglot.expressions.TimestampDiff'>, <class 'sqlglot.expressions.Floor'>, <class 'sqlglot.expressions.StrPosition'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<class 'sqlglot.expressions.TimestampSub'>, <class 'sqlglot.expressions.TimeSub'>, <class 'sqlglot.expressions.TimeAdd'>, <class 'sqlglot.expressions.TimestampAdd'>, <class 'sqlglot.expressions.CurrentTimestamp'>, <class 'sqlglot.expressions.CurrentTime'>, <class 'sqlglot.expressions.StrToTime'>, <class 'sqlglot.expressions.UnixToTime'>, <class 'sqlglot.expressions.TimeStrToTime'>}, <Type.TINYINT: 'TINYINT'>: {<class 'sqlglot.expressions.Week'>, <class 'sqlglot.expressions.Year'>, <class 'sqlglot.expressions.Day'>, <class 'sqlglot.expressions.Month'>}, <Type.VARCHAR: 'VARCHAR'>: {<class 'sqlglot.expressions.GroupConcat'>, <class 'sqlglot.expressions.ArrayConcat'>, <class 'sqlglot.expressions.Substring'>, <class 'sqlglot.expressions.TimeToTimeStr'>, <class 'sqlglot.expressions.ConcatWs'>, <class 'sqlglot.expressions.TimeToStr'>, <class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.TsOrDsToDateStr'>, <class 'sqlglot.expressions.SafeConcat'>, <class 'sqlglot.expressions.Upper'>, <class 'sqlglot.expressions.DateToDateStr'>, <class 'sqlglot.expressions.Trim'>, <class 'sqlglot.expressions.UnixToTimeStr'>, <class 'sqlglot.expressions.Initcap'>, <class 'sqlglot.expressions.Lower'>, <class 'sqlglot.expressions.UnixToStr'>}}
ANNOTATORS: Dict =
{<class 'sqlglot.expressions.Alias'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseNot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Neg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Not'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Paren'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Unary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Add'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.And'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContained'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayOverlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Binary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseAnd'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseLeftShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseOr'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseRightShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseXor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Collate'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Connector'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.DPipe'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Distance'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Div'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Dot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 
'sqlglot.expressions.EQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Escape'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Glob'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILike'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.IntDiv'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Is'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Kwarg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Like'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mod'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mul'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeEQ'>: <function 
TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeNEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Or'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Overlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Pow'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.RegexpLike'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDPipe'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.SimilarTo'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Slice'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Sub'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Xor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Length'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxDistinct'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArraySize'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Count'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Between'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Boolean'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.In'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Date'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 
'sqlglot.expressions.DateTrunc'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DiToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateFromParts'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDatetime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log2'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Avg'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log10'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Round'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ln'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Variance'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.VariancePop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevSamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevPop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Stddev'>: <function 
_annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Sqrt'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Quantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Exp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDivide'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxQuantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDiToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ceil'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Levenshtein'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Extract'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Floor'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrPosition'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampAdd'>: <function 
_annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTimestamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Week'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Year'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Day'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Month'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.GroupConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArrayConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Substring'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ConcatWs'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Concat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Upper'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDateStr'>: <function 
_annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Trim'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Initcap'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Lower'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Anonymous'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Cast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Case'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Coalesce'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DataType'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.If'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Interval'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Least'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Literal'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Map'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Max'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Min'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Null'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Sum'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.TryCast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.VarMap'>: <function TypeAnnotator.<lambda>>}
COERCES_TO: Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]] =
{<Type.TEXT: 'TEXT'>: set(), <Type.NVARCHAR: 'NVARCHAR'>: {<Type.TEXT: 'TEXT'>}, <Type.VARCHAR: 'VARCHAR'>: {<Type.TEXT: 'TEXT'>, <Type.NVARCHAR: 'NVARCHAR'>}, <Type.NCHAR: 'NCHAR'>: {<Type.TEXT: 'TEXT'>, <Type.VARCHAR: 'VARCHAR'>, <Type.NVARCHAR: 'NVARCHAR'>}, <Type.CHAR: 'CHAR'>: {<Type.NCHAR: 'NCHAR'>, <Type.TEXT: 'TEXT'>, <Type.VARCHAR: 'VARCHAR'>, <Type.NVARCHAR: 'NVARCHAR'>}, <Type.DOUBLE: 'DOUBLE'>: set(), <Type.FLOAT: 'FLOAT'>: {<Type.DOUBLE: 'DOUBLE'>}, <Type.DECIMAL: 'DECIMAL'>: {<Type.FLOAT: 'FLOAT'>, <Type.DOUBLE: 'DOUBLE'>}, <Type.BIGINT: 'BIGINT'>: {<Type.FLOAT: 'FLOAT'>, <Type.DECIMAL: 'DECIMAL'>, <Type.DOUBLE: 'DOUBLE'>}, <Type.INT: 'INT'>: {<Type.FLOAT: 'FLOAT'>, <Type.DECIMAL: 'DECIMAL'>, <Type.DOUBLE: 'DOUBLE'>, <Type.BIGINT: 'BIGINT'>}, <Type.SMALLINT: 'SMALLINT'>: {<Type.BIGINT: 'BIGINT'>, <Type.FLOAT: 'FLOAT'>, <Type.INT: 'INT'>, <Type.DECIMAL: 'DECIMAL'>, <Type.DOUBLE: 'DOUBLE'>}, <Type.TINYINT: 'TINYINT'>: {<Type.SMALLINT: 'SMALLINT'>, <Type.BIGINT: 'BIGINT'>, <Type.FLOAT: 'FLOAT'>, <Type.INT: 'INT'>, <Type.DECIMAL: 'DECIMAL'>, <Type.DOUBLE: 'DOUBLE'>}, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>: set(), <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>: {<Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.DATETIME: 'DATETIME'>: {<Type.TIMESTAMP: 'TIMESTAMP'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.DATE: 'DATE'>: {<Type.TIMESTAMP: 'TIMESTAMP'>, <Type.DATETIME: 'DATETIME'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}}
def
annotate(self, expression: ~E) -> ~E:
def annotate(self, expression: E) -> E:
    """
    Annotate `expression` with types, one scope at a time.

    For each scope yielded by `traverse_scope`, the qualified columns of that scope
    are typed first (from the schema for tables, or from the matching projection of
    a child scope), and then the scope's expression is handed to `_maybe_annotate`.

    Args:
        expression: The expression to annotate.

    Returns:
        The result of calling `_maybe_annotate` on `expression`.
    """
    for scope in traverse_scope(expression):
        # source name -> {output column name -> expression that produces it}
        selects = {}

        for source_name, source in scope.sources.items():
            if not isinstance(source, Scope):
                continue

            source_expr = source.expression

            if not isinstance(source_expr, exp.UDTF):
                # Non-UDTF scope source: index its projections by alias or name.
                selects[source_name] = {
                    projection.alias_or_name: projection for projection in source_expr.selects
                }
                continue

            if not isinstance(source_expr, exp.Lateral):
                # Non-lateral UDTF: pair the column aliases with the items of its first expression.
                values = source_expr.expressions[0].expressions
            elif isinstance(source_expr.this, exp.Explode):
                values = [source_expr.this.this]
            else:
                # Lateral over anything other than Explode: nothing to map.
                values = []

            if values:
                selects[source_name] = dict(zip(source_expr.alias_column_names, values))

        # Type the scope's column references first; unqualified columns are skipped.
        for column in scope.columns:
            table_name = column.table
            if not table_name:
                continue

            column_source = scope.sources.get(table_name)

            if isinstance(column_source, exp.Table):
                column.type = self.schema.get_column_type(column_source, column)
            elif column_source:
                projections = selects.get(table_name)
                if projections is not None and column.name in projections:
                    column.type = projections[column.name].type

        # With the columns typed, (possibly) annotate the rest of the scope
        self._maybe_annotate(scope.expression)

    # This takes care of non-traversable expressions
    return self._maybe_annotate(expression)