diff options
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r-- | sqlglot/expressions.py | 106 |
1 files changed, 86 insertions, 20 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 96b32f1..7249574 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -43,14 +43,14 @@ class Expression(metaclass=_Expression): key = "Expression" arg_types = {"this": True} - __slots__ = ("args", "parent", "arg_key", "type", "comments") + __slots__ = ("args", "parent", "arg_key", "comments", "_type") def __init__(self, **args): self.args = args self.parent = None self.arg_key = None - self.type = None self.comments = None + self._type: t.Optional[DataType] = None for arg_key, value in self.args.items(): self._set_parent(arg_key, value) @@ -122,6 +122,16 @@ class Expression(metaclass=_Expression): return "NULL" return self.alias or self.name + @property + def type(self) -> t.Optional[DataType]: + return self._type + + @type.setter + def type(self, dtype: t.Optional[DataType | DataType.Type | str]) -> None: + if dtype and not isinstance(dtype, DataType): + dtype = DataType.build(dtype) + self._type = dtype # type: ignore + def __deepcopy__(self, memo): copy = self.__class__(**deepcopy(self.args)) copy.comments = self.comments @@ -348,7 +358,7 @@ class Expression(metaclass=_Expression): indent += "".join([" "] * level) left = f"({self.key.upper()} " - args = { + args: t.Dict[str, t.Any] = { k: ", ".join( v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v) for v in ensure_collection(vs) @@ -612,6 +622,7 @@ class Create(Expression): "properties": False, "temporary": False, "transient": False, + "external": False, "replace": False, "unique": False, "materialized": False, @@ -744,13 +755,17 @@ class DefaultColumnConstraint(ColumnConstraintKind): pass +class EncodeColumnConstraint(ColumnConstraintKind): + pass + + class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind): # this: True -> ALWAYS, this: False -> BY DEFAULT arg_types = {"this": True, "expression": False} class NotNullColumnConstraint(ColumnConstraintKind): - pass + arg_types = {"allow_null": False} class PrimaryKeyColumnConstraint(ColumnConstraintKind): @@ -766,7 +781,7 @@ class Constraint(Expression): class Delete(Expression): - arg_types = {"with": False, "this": True, "using": False, "where": False} + arg_types = {"with": False, "this": False, "using": False, "where": False} class Drop(Expression): @@ -850,7 +865,7 @@ class Insert(Expression): arg_types = { "with": False, "this": True, - "expression": True, + "expression": False, "overwrite": False, "exists": False, "partition": False, @@ -1125,6 +1140,27 @@ class VolatilityProperty(Property): arg_types = {"this": True} +class RowFormatDelimitedProperty(Property): + # https://cwiki.apache.org/confluence/display/hive/languagemanual+dml + arg_types = { + "fields": False, + "escaped": False, + "collection_items": False, + "map_keys": False, + "lines": False, + "null": False, + "serde": False, + } + + +class RowFormatSerdeProperty(Property): + arg_types = {"this": True} + + +class SerdeProperties(Property): + arg_types = {"expressions": True} + + class Properties(Expression): arg_types = {"expressions": True} @@ -1169,18 +1205,6 @@ class Reference(Expression): arg_types = {"this": True, "expressions": True} -class RowFormat(Expression): - # https://cwiki.apache.org/confluence/display/hive/languagemanual+dml - arg_types = { - "fields": False, - "escaped": False, - "collection_items": False, - "map_keys": False, - "lines": False, - "null": False, - } - - class Tuple(Expression): arg_types = {"expressions": False} @@ -1208,6 +1232,9 @@ class Subqueryable(Unionable): alias=TableAlias(this=to_identifier(alias)), ) + def limit(self, expression, dialect=None, copy=True, **opts) -> Select: + raise NotImplementedError + @property def ctes(self): with_ = self.args.get("with") @@ -1320,6 +1347,32 @@ class Union(Subqueryable): **QUERY_MODIFIERS, } + def limit(self, expression, dialect=None, copy=True, **opts) -> Select: + """ + Set the LIMIT expression. + + Example: + >>> select("1").union(select("1")).limit(1).sql() + 'SELECT * FROM (SELECT 1 UNION SELECT 1) AS "_l_0" LIMIT 1' + + Args: + expression (str | int | Expression): the SQL code string to parse. + This can also be an integer. + If a `Limit` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in a `Limit`. + dialect (str): the dialect used to parse the input expression. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Select: The limited subqueryable. + """ + return ( + select("*") + .from_(self.subquery(alias="_l_0", copy=copy)) + .limit(expression, dialect=dialect, copy=False, **opts) + ) + @property def named_selects(self): return self.this.unnest().named_selects @@ -1356,7 +1409,7 @@ class Unnest(UDTF): class Update(Expression): arg_types = { "with": False, - "this": True, + "this": False, "expressions": True, "from": False, "where": False, @@ -2057,15 +2110,20 @@ class DataType(Expression): Type.TEXT, } - NUMERIC_TYPES = { + INTEGER_TYPES = { Type.INT, Type.TINYINT, Type.SMALLINT, Type.BIGINT, + } + + FLOAT_TYPES = { Type.FLOAT, Type.DOUBLE, } + NUMERIC_TYPES = {*INTEGER_TYPES, *FLOAT_TYPES} + TEMPORAL_TYPES = { Type.TIMESTAMP, Type.TIMESTAMPTZ, @@ -2968,6 +3026,14 @@ class Use(Expression): pass +class Merge(Expression): + arg_types = {"this": True, "using": True, "on": True, "expressions": True} + + +class When(Func): + arg_types = {"this": True, "then": True} + + def _norm_args(expression): args = {} |