summaryrefslogtreecommitdiffstats
path: root/sqlglot/expressions.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r--sqlglot/expressions.py74
1 files changed, 55 insertions, 19 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index b983bf9..9299132 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -1,4 +1,5 @@
import inspect
+import numbers
import re
import sys
from collections import deque
@@ -6,7 +7,7 @@ from copy import deepcopy
from enum import auto
from sqlglot.errors import ParseError
-from sqlglot.helper import AutoName, camel_to_snake_case, ensure_list
+from sqlglot.helper import AutoName, camel_to_snake_case, ensure_list, list_get
class _Expression(type):
@@ -350,7 +351,8 @@ class Expression(metaclass=_Expression):
Args:
fun (function): a function which takes a node as an argument and returns a
- new transformed node or the same node without modifications.
+ new transformed node or the same node without modifications. If the function
+ returns None, then the corresponding node will be removed from the syntax tree.
copy (bool): if set to True a new tree instance is constructed, otherwise the tree is
modified in place.
@@ -360,9 +362,7 @@ class Expression(metaclass=_Expression):
node = self.copy() if copy else self
new_node = fun(node, *args, **kwargs)
- if new_node is None:
- raise ValueError("A transformed node cannot be None")
- if not isinstance(new_node, Expression):
+ if new_node is None or not isinstance(new_node, Expression):
return new_node
if new_node is not node:
new_node.parent = node.parent
@@ -843,10 +843,6 @@ class Ordered(Expression):
arg_types = {"this": True, "desc": True, "nulls_first": True}
-class Properties(Expression):
- arg_types = {"expressions": True}
-
-
class Property(Expression):
arg_types = {"this": True, "value": True}
@@ -891,6 +887,42 @@ class AnonymousProperty(Property):
pass
+class Properties(Expression):
+ arg_types = {"expressions": True}
+
+ PROPERTY_KEY_MAPPING = {
+ "AUTO_INCREMENT": AutoIncrementProperty,
+ "CHARACTER_SET": CharacterSetProperty,
+ "COLLATE": CollateProperty,
+ "COMMENT": SchemaCommentProperty,
+ "ENGINE": EngineProperty,
+ "FORMAT": FileFormatProperty,
+ "LOCATION": LocationProperty,
+ "PARTITIONED_BY": PartitionedByProperty,
+ "TABLE_FORMAT": TableFormatProperty,
+ }
+
+ @classmethod
+ def from_dict(cls, properties_dict):
+ expressions = []
+ for key, value in properties_dict.items():
+ property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty)
+ expressions.append(property_cls(this=Literal.string(key), value=cls._convert_value(value)))
+ return cls(expressions=expressions)
+
+ @staticmethod
+ def _convert_value(value):
+ if isinstance(value, Expression):
+ return value
+ if isinstance(value, str):
+ return Literal.string(value)
+ if isinstance(value, numbers.Number):
+ return Literal.number(value)
+ if isinstance(value, list):
+ return Tuple(expressions=[_convert_value(v) for v in value])
+ raise ValueError(f"Unsupported type '{type(value)}' for value '{value}'")
+
+
class Qualify(Expression):
pass
@@ -1562,15 +1594,7 @@ class Select(Subqueryable, Expression):
)
properties_expression = None
if properties:
- properties_str = " ".join(
- [f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" for k, v in properties.items()]
- )
- properties_expression = maybe_parse(
- properties_str,
- into=Properties,
- dialect=dialect,
- **opts,
- )
+ properties_expression = Properties.from_dict(properties)
return Create(
this=table_expression,
@@ -1650,6 +1674,10 @@ class Star(Expression):
return "*"
+class Parameter(Expression):
+ pass
+
+
class Placeholder(Expression):
arg_types = {}
@@ -1688,6 +1716,7 @@ class DataType(Expression):
INTERVAL = auto()
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
+ TIMESTAMPLTZ = auto()
DATE = auto()
DATETIME = auto()
ARRAY = auto()
@@ -1702,6 +1731,13 @@ class DataType(Expression):
SERIAL = auto()
SMALLSERIAL = auto()
BIGSERIAL = auto()
+ XML = auto()
+ UNIQUEIDENTIFIER = auto()
+ MONEY = auto()
+ SMALLMONEY = auto()
+ ROWVERSION = auto()
+ IMAGE = auto()
+ SQL_VARIANT = auto()
@classmethod
def build(cls, dtype, **kwargs):
@@ -2976,7 +3012,7 @@ def replace_children(expression, fun):
else:
new_child_nodes.append(cn)
- expression.args[k] = new_child_nodes if is_list_arg else new_child_nodes[0]
+ expression.args[k] = new_child_nodes if is_list_arg else list_get(new_child_nodes, 0)
def column_table_names(expression):