summaryrefslogtreecommitdiffstats
path: root/sqlglot/expressions.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r--sqlglot/expressions.py106
1 files changed, 86 insertions, 20 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 96b32f1..7249574 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -43,14 +43,14 @@ class Expression(metaclass=_Expression):
key = "Expression"
arg_types = {"this": True}
- __slots__ = ("args", "parent", "arg_key", "type", "comments")
+ __slots__ = ("args", "parent", "arg_key", "comments", "_type")
def __init__(self, **args):
self.args = args
self.parent = None
self.arg_key = None
- self.type = None
self.comments = None
+ self._type: t.Optional[DataType] = None
for arg_key, value in self.args.items():
self._set_parent(arg_key, value)
@@ -122,6 +122,16 @@ class Expression(metaclass=_Expression):
return "NULL"
return self.alias or self.name
+ @property
+ def type(self) -> t.Optional[DataType]:
+ return self._type
+
+ @type.setter
+ def type(self, dtype: t.Optional[DataType | DataType.Type | str]) -> None:
+ if dtype and not isinstance(dtype, DataType):
+ dtype = DataType.build(dtype)
+ self._type = dtype # type: ignore
+
def __deepcopy__(self, memo):
copy = self.__class__(**deepcopy(self.args))
copy.comments = self.comments
@@ -348,7 +358,7 @@ class Expression(metaclass=_Expression):
indent += "".join([" "] * level)
left = f"({self.key.upper()} "
- args = {
+ args: t.Dict[str, t.Any] = {
k: ", ".join(
v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v)
for v in ensure_collection(vs)
@@ -612,6 +622,7 @@ class Create(Expression):
"properties": False,
"temporary": False,
"transient": False,
+ "external": False,
"replace": False,
"unique": False,
"materialized": False,
@@ -744,13 +755,17 @@ class DefaultColumnConstraint(ColumnConstraintKind):
pass
+class EncodeColumnConstraint(ColumnConstraintKind):
+ pass
+
+
class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
# this: True -> ALWAYS, this: False -> BY DEFAULT
arg_types = {"this": True, "expression": False}
class NotNullColumnConstraint(ColumnConstraintKind):
- pass
+ arg_types = {"allow_null": False}
class PrimaryKeyColumnConstraint(ColumnConstraintKind):
@@ -766,7 +781,7 @@ class Constraint(Expression):
class Delete(Expression):
- arg_types = {"with": False, "this": True, "using": False, "where": False}
+ arg_types = {"with": False, "this": False, "using": False, "where": False}
class Drop(Expression):
@@ -850,7 +865,7 @@ class Insert(Expression):
arg_types = {
"with": False,
"this": True,
- "expression": True,
+ "expression": False,
"overwrite": False,
"exists": False,
"partition": False,
@@ -1125,6 +1140,27 @@ class VolatilityProperty(Property):
arg_types = {"this": True}
+class RowFormatDelimitedProperty(Property):
+ # https://cwiki.apache.org/confluence/display/hive/languagemanual+dml
+ arg_types = {
+ "fields": False,
+ "escaped": False,
+ "collection_items": False,
+ "map_keys": False,
+ "lines": False,
+ "null": False,
+ "serde": False,
+ }
+
+
+class RowFormatSerdeProperty(Property):
+ arg_types = {"this": True}
+
+
+class SerdeProperties(Property):
+ arg_types = {"expressions": True}
+
+
class Properties(Expression):
arg_types = {"expressions": True}
@@ -1169,18 +1205,6 @@ class Reference(Expression):
arg_types = {"this": True, "expressions": True}
-class RowFormat(Expression):
- # https://cwiki.apache.org/confluence/display/hive/languagemanual+dml
- arg_types = {
- "fields": False,
- "escaped": False,
- "collection_items": False,
- "map_keys": False,
- "lines": False,
- "null": False,
- }
-
-
class Tuple(Expression):
arg_types = {"expressions": False}
@@ -1208,6 +1232,9 @@ class Subqueryable(Unionable):
alias=TableAlias(this=to_identifier(alias)),
)
+ def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
+ raise NotImplementedError
+
@property
def ctes(self):
with_ = self.args.get("with")
@@ -1320,6 +1347,32 @@ class Union(Subqueryable):
**QUERY_MODIFIERS,
}
+ def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
+ """
+ Set the LIMIT expression.
+
+ Example:
+ >>> select("1").union(select("1")).limit(1).sql()
+ 'SELECT * FROM (SELECT 1 UNION SELECT 1) AS "_l_0" LIMIT 1'
+
+ Args:
+ expression (str | int | Expression): the SQL code string to parse.
+ This can also be an integer.
+ If a `Limit` instance is passed, this is used as-is.
+ If another `Expression` instance is passed, it will be wrapped in a `Limit`.
+ dialect (str): the dialect used to parse the input expression.
+ copy (bool): if `False`, modify this expression instance in-place.
+ opts (kwargs): other options to use to parse the input expressions.
+
+ Returns:
+ Select: The limited subqueryable.
+ """
+ return (
+ select("*")
+ .from_(self.subquery(alias="_l_0", copy=copy))
+ .limit(expression, dialect=dialect, copy=False, **opts)
+ )
+
@property
def named_selects(self):
return self.this.unnest().named_selects
@@ -1356,7 +1409,7 @@ class Unnest(UDTF):
class Update(Expression):
arg_types = {
"with": False,
- "this": True,
+ "this": False,
"expressions": True,
"from": False,
"where": False,
@@ -2057,15 +2110,20 @@ class DataType(Expression):
Type.TEXT,
}
- NUMERIC_TYPES = {
+ INTEGER_TYPES = {
Type.INT,
Type.TINYINT,
Type.SMALLINT,
Type.BIGINT,
+ }
+
+ FLOAT_TYPES = {
Type.FLOAT,
Type.DOUBLE,
}
+ NUMERIC_TYPES = {*INTEGER_TYPES, *FLOAT_TYPES}
+
TEMPORAL_TYPES = {
Type.TIMESTAMP,
Type.TIMESTAMPTZ,
@@ -2968,6 +3026,14 @@ class Use(Expression):
pass
+class Merge(Expression):
+ arg_types = {"this": True, "using": True, "on": True, "expressions": True}
+
+
+class When(Func):
+ arg_types = {"this": True, "then": True}
+
+
def _norm_args(expression):
args = {}