summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/annotate_types.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/annotate_types.py')
-rw-r--r--sqlglot/optimizer/annotate_types.py162
1 files changed, 162 insertions, 0 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
new file mode 100644
index 0000000..3f5f089
--- /dev/null
+++ b/sqlglot/optimizer/annotate_types.py
@@ -0,0 +1,162 @@
+from sqlglot import exp
+from sqlglot.helper import ensure_list, subclasses
+
+
+def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
+ """
+ Recursively infer & annotate types in an expression syntax tree against a schema.
+
+ (TODO -- replace this with a better example after adding some functionality)
+ Example:
+ >>> import sqlglot
+ >>> annotated_expression = annotate_types(sqlglot.parse_one('5 + 5.3'))
+ >>> annotated_expression.type
+ <Type.DOUBLE: 'DOUBLE'>
+
+ Args:
+ expression (sqlglot.Expression): Expression to annotate.
+ schema (dict|sqlglot.optimizer.Schema): Database schema.
+ annotators (dict): Maps expression type to corresponding annotation function.
+ coerces_to (dict): Maps expression type to set of types that it can be coerced into.
+ Returns:
+ sqlglot.Expression: expression annotated with types
+ """
+
+ return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
+
+
+class TypeAnnotator:
+ ANNOTATORS = {
+ **{
+ expr_type: lambda self, expr: self._annotate_unary(expr)
+ for expr_type in subclasses(exp.__name__, exp.Unary)
+ },
+ **{
+ expr_type: lambda self, expr: self._annotate_binary(expr)
+ for expr_type in subclasses(exp.__name__, exp.Binary)
+ },
+ exp.Cast: lambda self, expr: self._annotate_cast(expr),
+ exp.DataType: lambda self, expr: self._annotate_data_type(expr),
+ exp.Literal: lambda self, expr: self._annotate_literal(expr),
+ exp.Boolean: lambda self, expr: self._annotate_boolean(expr),
+ }
+
+ # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
+ COERCES_TO = {
+ # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
+ exp.DataType.Type.TEXT: set(),
+ exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
+ exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
+ exp.DataType.Type.NCHAR: {exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
+ exp.DataType.Type.CHAR: {
+ exp.DataType.Type.NCHAR,
+ exp.DataType.Type.VARCHAR,
+ exp.DataType.Type.NVARCHAR,
+ exp.DataType.Type.TEXT,
+ },
+ # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
+ exp.DataType.Type.DOUBLE: set(),
+ exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
+ exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
+ exp.DataType.Type.BIGINT: {exp.DataType.Type.DECIMAL, exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
+ exp.DataType.Type.INT: {
+ exp.DataType.Type.BIGINT,
+ exp.DataType.Type.DECIMAL,
+ exp.DataType.Type.FLOAT,
+ exp.DataType.Type.DOUBLE,
+ },
+ exp.DataType.Type.SMALLINT: {
+ exp.DataType.Type.INT,
+ exp.DataType.Type.BIGINT,
+ exp.DataType.Type.DECIMAL,
+ exp.DataType.Type.FLOAT,
+ exp.DataType.Type.DOUBLE,
+ },
+ exp.DataType.Type.TINYINT: {
+ exp.DataType.Type.SMALLINT,
+ exp.DataType.Type.INT,
+ exp.DataType.Type.BIGINT,
+ exp.DataType.Type.DECIMAL,
+ exp.DataType.Type.FLOAT,
+ exp.DataType.Type.DOUBLE,
+ },
+ # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
+ exp.DataType.Type.TIMESTAMPLTZ: set(),
+ exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
+ exp.DataType.Type.TIMESTAMP: {exp.DataType.Type.TIMESTAMPTZ, exp.DataType.Type.TIMESTAMPLTZ},
+ exp.DataType.Type.DATETIME: {
+ exp.DataType.Type.TIMESTAMP,
+ exp.DataType.Type.TIMESTAMPTZ,
+ exp.DataType.Type.TIMESTAMPLTZ,
+ },
+ exp.DataType.Type.DATE: {
+ exp.DataType.Type.DATETIME,
+ exp.DataType.Type.TIMESTAMP,
+ exp.DataType.Type.TIMESTAMPTZ,
+ exp.DataType.Type.TIMESTAMPLTZ,
+ },
+ }
+
+ def __init__(self, schema=None, annotators=None, coerces_to=None):
+ self.schema = schema
+ self.annotators = annotators or self.ANNOTATORS
+ self.coerces_to = coerces_to or self.COERCES_TO
+
+ def annotate(self, expression):
+ if not isinstance(expression, exp.Expression):
+ return None
+
+ annotator = self.annotators.get(expression.__class__)
+ return annotator(self, expression) if annotator else self._annotate_args(expression)
+
+ def _annotate_args(self, expression):
+ for value in expression.args.values():
+ for v in ensure_list(value):
+ self.annotate(v)
+
+ return expression
+
+ def _annotate_cast(self, expression):
+ expression.type = expression.args["to"].this
+ return self._annotate_args(expression)
+
+ def _annotate_data_type(self, expression):
+ expression.type = expression.this
+ return self._annotate_args(expression)
+
+ def _maybe_coerce(self, type1, type2):
+ return type2 if type2 in self.coerces_to[type1] else type1
+
+ def _annotate_binary(self, expression):
+ self._annotate_args(expression)
+
+ if isinstance(expression, (exp.Condition, exp.Predicate)):
+ expression.type = exp.DataType.Type.BOOLEAN
+ else:
+ expression.type = self._maybe_coerce(expression.left.type, expression.right.type)
+
+ return expression
+
+ def _annotate_unary(self, expression):
+ self._annotate_args(expression)
+
+ if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
+ expression.type = exp.DataType.Type.BOOLEAN
+ else:
+ expression.type = expression.this.type
+
+ return expression
+
+ def _annotate_literal(self, expression):
+ if expression.is_string:
+ expression.type = exp.DataType.Type.VARCHAR
+ elif expression.is_int:
+ expression.type = exp.DataType.Type.INT
+ else:
+ expression.type = exp.DataType.Type.DOUBLE
+
+ return expression
+
+ def _annotate_boolean(self, expression):
+ expression.type = exp.DataType.Type.BOOLEAN
+ return expression