diff options
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r-- | sqlglot/expressions.py | 111 |
1 files changed, 65 insertions, 46 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index beafca8..96b32f1 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -43,14 +43,14 @@ class Expression(metaclass=_Expression): key = "Expression" arg_types = {"this": True} - __slots__ = ("args", "parent", "arg_key", "type", "comment") + __slots__ = ("args", "parent", "arg_key", "type", "comments") def __init__(self, **args): self.args = args self.parent = None self.arg_key = None self.type = None - self.comment = None + self.comments = None for arg_key, value in self.args.items(): self._set_parent(arg_key, value) @@ -88,19 +88,6 @@ class Expression(metaclass=_Expression): return field.this return "" - def find_comment(self, key: str) -> str: - """ - Finds the comment that is attached to a specified child node. - - Args: - key: the key of the target child node (e.g. "this", "expression", etc). - - Returns: - The comment attached to the child node, or the empty string, if it doesn't exist. - """ - field = self.args.get(key) - return field.comment if isinstance(field, Expression) else "" - @property def is_string(self): return isinstance(self, Literal) and self.args["is_string"] @@ -137,7 +124,7 @@ class Expression(metaclass=_Expression): def __deepcopy__(self, memo): copy = self.__class__(**deepcopy(self.args)) - copy.comment = self.comment + copy.comments = self.comments copy.type = self.type return copy @@ -369,7 +356,7 @@ class Expression(metaclass=_Expression): ) for k, vs in self.args.items() } - args["comment"] = self.comment + args["comments"] = self.comments args["type"] = self.type args = {k: v for k, v in args.items() if v or not hide_missing} @@ -767,7 +754,7 @@ class NotNullColumnConstraint(ColumnConstraintKind): class PrimaryKeyColumnConstraint(ColumnConstraintKind): - pass + arg_types = {"desc": False} class UniqueColumnConstraint(ColumnConstraintKind): @@ -819,6 +806,12 @@ class Unique(Expression): arg_types = {"expressions": True} +# https://www.postgresql.org/docs/9.1/sql-selectinto.html +# https://docs.aws.amazon.com/redshift/latest/dg/r_SELECT_INTO.html#r_SELECT_INTO-examples +class Into(Expression): + arg_types = {"this": True, "temporary": False, "unlogged": False} + + class From(Expression): arg_types = {"expressions": True} @@ -1065,67 +1058,67 @@ class Property(Expression): class TableFormatProperty(Property): - pass + arg_types = {"this": True} class PartitionedByProperty(Property): - pass + arg_types = {"this": True} class FileFormatProperty(Property): - pass + arg_types = {"this": True} class DistKeyProperty(Property): - pass + arg_types = {"this": True} class SortKeyProperty(Property): - pass + arg_types = {"this": True, "compound": False} class DistStyleProperty(Property): - pass + arg_types = {"this": True} + + +class LikeProperty(Property): + arg_types = {"this": True, "expressions": False} class LocationProperty(Property): - pass + arg_types = {"this": True} class EngineProperty(Property): - pass + arg_types = {"this": True} class AutoIncrementProperty(Property): - pass + arg_types = {"this": True} class CharacterSetProperty(Property): - arg_types = {"this": True, "value": True, "default": True} + arg_types = {"this": True, "default": True} class CollateProperty(Property): - pass + arg_types = {"this": True} class SchemaCommentProperty(Property): - pass - - -class AnonymousProperty(Property): - pass + arg_types = {"this": True} class ReturnsProperty(Property): - arg_types = {"this": True, "value": True, "is_table": False} + arg_types = {"this": True, "is_table": False} class LanguageProperty(Property): - pass + arg_types = {"this": True} class ExecuteAsProperty(Property): - pass + arg_types = {"this": True} class VolatilityProperty(Property): @@ -1135,27 +1128,36 @@ class VolatilityProperty(Property): class Properties(Expression): arg_types = {"expressions": True} - PROPERTY_KEY_MAPPING = { + NAME_TO_PROPERTY = { "AUTO_INCREMENT": AutoIncrementProperty, - "CHARACTER_SET": CharacterSetProperty, + "CHARACTER SET": CharacterSetProperty, "COLLATE": CollateProperty, "COMMENT": SchemaCommentProperty, + "DISTKEY": DistKeyProperty, + "DISTSTYLE": DistStyleProperty, "ENGINE": EngineProperty, + "EXECUTE AS": ExecuteAsProperty, "FORMAT": FileFormatProperty, + "LANGUAGE": LanguageProperty, "LOCATION": LocationProperty, "PARTITIONED_BY": PartitionedByProperty, - "TABLE_FORMAT": TableFormatProperty, - "DISTKEY": DistKeyProperty, - "DISTSTYLE": DistStyleProperty, + "RETURNS": ReturnsProperty, "SORTKEY": SortKeyProperty, + "TABLE_FORMAT": TableFormatProperty, } + PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()} + @classmethod def from_dict(cls, properties_dict) -> Properties: expressions = [] for key, value in properties_dict.items(): - property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty) - expressions.append(property_cls(this=Literal.string(key), value=convert(value))) + property_cls = cls.NAME_TO_PROPERTY.get(key.upper()) + if property_cls: + expressions.append(property_cls(this=convert(value))) + else: + expressions.append(Property(this=Literal.string(key), value=convert(value))) + return cls(expressions=expressions) @@ -1383,6 +1385,7 @@ class Select(Subqueryable): "expressions": False, "hint": False, "distinct": False, + "into": False, "from": False, **QUERY_MODIFIERS, } @@ -2015,6 +2018,7 @@ class DataType(Expression): DECIMAL = auto() BOOLEAN = auto() JSON = auto() + JSONB = auto() INTERVAL = auto() TIMESTAMP = auto() TIMESTAMPTZ = auto() @@ -2029,6 +2033,7 @@ class DataType(Expression): STRUCT = auto() NULLABLE = auto() HLLSKETCH = auto() + HSTORE = auto() SUPER = auto() SERIAL = auto() SMALLSERIAL = auto() @@ -2109,7 +2114,7 @@ class Transaction(Command): class Commit(Command): - arg_types = {} # type: ignore + arg_types = {"chain": False} class Rollback(Command): @@ -2442,7 +2447,7 @@ class ArrayFilter(Func): class ArraySize(Func): - pass + arg_types = {"this": True, "expression": False} class ArraySort(Func): @@ -2726,6 +2731,16 @@ class VarMap(Func): is_var_len_args = True +class Matches(Func): + """Oracle/Snowflake decode. + https://docs.oracle.com/cd/B19306_01/server.102/b14200/functions040.htm + Pattern matching MATCHES(value, search1, result1, ...searchN, resultN, else) + """ + + arg_types = {"this": True, "expressions": True} + is_var_len_args = True + + class Max(AggFunc): pass @@ -2785,6 +2800,10 @@ class Round(Func): arg_types = {"this": True, "decimals": False} +class RowNumber(Func): + arg_types: t.Dict[str, t.Any] = {} + + class SafeDivide(Func): arg_types = {"this": True, "expression": True} |