diff options
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r-- | sqlglot/expressions.py | 74 |
1 files changed, 55 insertions, 19 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index b983bf9..9299132 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1,4 +1,5 @@ import inspect +import numbers import re import sys from collections import deque @@ -6,7 +7,7 @@ from copy import deepcopy from enum import auto from sqlglot.errors import ParseError -from sqlglot.helper import AutoName, camel_to_snake_case, ensure_list +from sqlglot.helper import AutoName, camel_to_snake_case, ensure_list, list_get class _Expression(type): @@ -350,7 +351,8 @@ class Expression(metaclass=_Expression): Args: fun (function): a function which takes a node as an argument and returns a - new transformed node or the same node without modifications. + new transformed node or the same node without modifications. If the function + returns None, then the corresponding node will be removed from the syntax tree. copy (bool): if set to True a new tree instance is constructed, otherwise the tree is modified in place. @@ -360,9 +362,7 @@ class Expression(metaclass=_Expression): node = self.copy() if copy else self new_node = fun(node, *args, **kwargs) - if new_node is None: - raise ValueError("A transformed node cannot be None") - if not isinstance(new_node, Expression): + if new_node is None or not isinstance(new_node, Expression): return new_node if new_node is not node: new_node.parent = node.parent @@ -843,10 +843,6 @@ class Ordered(Expression): arg_types = {"this": True, "desc": True, "nulls_first": True} -class Properties(Expression): - arg_types = {"expressions": True} - - class Property(Expression): arg_types = {"this": True, "value": True} @@ -891,6 +887,42 @@ class AnonymousProperty(Property): pass +class Properties(Expression): + arg_types = {"expressions": True} + + PROPERTY_KEY_MAPPING = { + "AUTO_INCREMENT": AutoIncrementProperty, + "CHARACTER_SET": CharacterSetProperty, + "COLLATE": CollateProperty, + "COMMENT": SchemaCommentProperty, + "ENGINE": EngineProperty, + "FORMAT": FileFormatProperty, + "LOCATION": LocationProperty, + "PARTITIONED_BY": PartitionedByProperty, + "TABLE_FORMAT": TableFormatProperty, + } + + @classmethod + def from_dict(cls, properties_dict): + expressions = [] + for key, value in properties_dict.items(): + property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty) + expressions.append(property_cls(this=Literal.string(key), value=cls._convert_value(value))) + return cls(expressions=expressions) + + @staticmethod + def _convert_value(value): + if isinstance(value, Expression): + return value + if isinstance(value, str): + return Literal.string(value) + if isinstance(value, numbers.Number): + return Literal.number(value) + if isinstance(value, list): + return Tuple(expressions=[_convert_value(v) for v in value]) + raise ValueError(f"Unsupported type '{type(value)}' for value '{value}'") + + class Qualify(Expression): pass @@ -1562,15 +1594,7 @@ class Select(Subqueryable, Expression): ) properties_expression = None if properties: - properties_str = " ".join( - [f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" for k, v in properties.items()] - ) - properties_expression = maybe_parse( - properties_str, - into=Properties, - dialect=dialect, - **opts, - ) + properties_expression = Properties.from_dict(properties) return Create( this=table_expression, @@ -1650,6 +1674,10 @@ class Star(Expression): return "*" +class Parameter(Expression): + pass + + class Placeholder(Expression): arg_types = {} @@ -1688,6 +1716,7 @@ class DataType(Expression): INTERVAL = auto() TIMESTAMP = auto() TIMESTAMPTZ = auto() + TIMESTAMPLTZ = auto() DATE = auto() DATETIME = auto() ARRAY = auto() @@ -1702,6 +1731,13 @@ class DataType(Expression): SERIAL = auto() SMALLSERIAL = auto() BIGSERIAL = auto() + XML = auto() + UNIQUEIDENTIFIER = auto() + MONEY = auto() + SMALLMONEY = auto() + ROWVERSION = auto() + IMAGE = auto() + SQL_VARIANT = auto() @classmethod def build(cls, dtype, **kwargs): @@ -2976,7 +3012,7 @@ def replace_children(expression, fun): else: new_child_nodes.append(cn) - expression.args[k] = new_child_nodes if is_list_arg else new_child_nodes[0] + expression.args[k] = new_child_nodes if is_list_arg else list_get(new_child_nodes, 0) def column_table_names(expression): |