diff options
Diffstat (limited to 'sqlglot/optimizer/annotate_types.py')
-rw-r--r-- | sqlglot/optimizer/annotate_types.py | 12 |
1 files changed, 5 insertions, 7 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index c2d6655..99888c6 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -1,5 +1,5 @@ from sqlglot import exp -from sqlglot.helper import ensure_collection, ensure_list, subclasses +from sqlglot.helper import ensure_list, subclasses from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.schema import ensure_schema @@ -108,6 +108,7 @@ class TypeAnnotator: exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"), exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"), exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"), + exp.Concat: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), exp.GroupConcat: lambda self, expr: self._annotate_with_type( expr, exp.DataType.Type.VARCHAR @@ -116,6 +117,7 @@ class TypeAnnotator: expr, exp.DataType.Type.VARCHAR ), exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL), exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"), exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), @@ -296,9 +298,6 @@ class TypeAnnotator: return self._maybe_annotate(expression) # This takes care of non-traversable expressions def _maybe_annotate(self, expression): - if not isinstance(expression, exp.Expression): - return None - if expression.type: return expression # We've already inferred the expression's type @@ -311,9 +310,8 @@ class TypeAnnotator: ) def _annotate_args(self, expression): - for value in expression.args.values(): - for v in ensure_collection(value): - self._maybe_annotate(v) + for _, value in expression.iter_expressions(): + self._maybe_annotate(value) return expression |