summaryrefslogtreecommitdiffstats
path: root/sqlglot/expressions.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r--sqlglot/expressions.py100
1 files changed, 78 insertions, 22 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 9a6b440..f8e9fee 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -67,8 +67,9 @@ class Expression(metaclass=_Expression):
uses to refer to it.
comments: a list of comments that are associated with a given expression. This is used in
order to preserve comments when transpiling SQL code.
- _type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the
+ type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the
optimizer, in order to enable some transformations that require type information.
+ meta: a dictionary that can be used to store useful metadata for a given expression.
Example:
>>> class Foo(Expression):
@@ -767,7 +768,7 @@ class Condition(Expression):
**opts,
) -> In:
return In(
- this=_maybe_copy(self, copy),
+ this=maybe_copy(self, copy),
expressions=[convert(e, copy=copy) for e in expressions],
query=maybe_parse(query, copy=copy, **opts) if query else None,
unnest=Unnest(
@@ -781,7 +782,7 @@ class Condition(Expression):
def between(self, low: t.Any, high: t.Any, copy: bool = True, **opts) -> Between:
return Between(
- this=_maybe_copy(self, copy),
+ this=maybe_copy(self, copy),
low=convert(low, copy=copy, **opts),
high=convert(high, copy=copy, **opts),
)
@@ -990,7 +991,28 @@ class Uncache(Expression):
arg_types = {"this": True, "exists": False}
-class Create(Expression):
+class DDL(Expression):
+ @property
+ def ctes(self):
+ with_ = self.args.get("with")
+ if not with_:
+ return []
+ return with_.expressions
+
+ @property
+ def named_selects(self) -> t.List[str]:
+ if isinstance(self.expression, Subqueryable):
+ return self.expression.named_selects
+ return []
+
+ @property
+ def selects(self) -> t.List[Expression]:
+ if isinstance(self.expression, Subqueryable):
+ return self.expression.selects
+ return []
+
+
+class Create(DDL):
arg_types = {
"with": False,
"this": True,
@@ -1206,6 +1228,19 @@ class MergeTreeTTL(Expression):
}
+# https://dev.mysql.com/doc/refman/8.0/en/create-table.html
+class IndexConstraintOption(Expression):
+ arg_types = {
+ "key_block_size": False,
+ "using": False,
+ "parser": False,
+ "comment": False,
+ "visible": False,
+ "engine_attr": False,
+ "secondary_engine_attr": False,
+ }
+
+
class ColumnConstraint(Expression):
arg_types = {"this": False, "kind": True}
@@ -1272,6 +1307,11 @@ class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
}
+# https://dev.mysql.com/doc/refman/8.0/en/create-table.html
+class IndexColumnConstraint(ColumnConstraintKind):
+ arg_types = {"this": False, "schema": True, "kind": False, "type": False, "options": False}
+
+
class InlineLengthColumnConstraint(ColumnConstraintKind):
pass
@@ -1496,7 +1536,7 @@ class JoinHint(Expression):
class Identifier(Expression):
- arg_types = {"this": True, "quoted": False}
+ arg_types = {"this": True, "quoted": False, "global": False, "temporary": False}
@property
def quoted(self) -> bool:
@@ -1525,7 +1565,7 @@ class Index(Expression):
}
-class Insert(Expression):
+class Insert(DDL):
arg_types = {
"with": False,
"this": True,
@@ -1892,6 +1932,10 @@ class EngineProperty(Property):
arg_types = {"this": True}
+class HeapProperty(Property):
+ arg_types = {}
+
+
class ToTableProperty(Property):
arg_types = {"this": True}
@@ -2182,7 +2226,7 @@ class Tuple(Expression):
**opts,
) -> In:
return In(
- this=_maybe_copy(self, copy),
+ this=maybe_copy(self, copy),
expressions=[convert(e, copy=copy) for e in expressions],
query=maybe_parse(query, copy=copy, **opts) if query else None,
unnest=Unnest(
@@ -2212,7 +2256,7 @@ class Subqueryable(Unionable):
Returns:
Alias: the subquery
"""
- instance = _maybe_copy(self, copy)
+ instance = maybe_copy(self, copy)
if not isinstance(alias, Expression):
alias = TableAlias(this=to_identifier(alias)) if alias else None
@@ -2865,7 +2909,7 @@ class Select(Subqueryable):
self,
expression: ExpOrStr,
on: t.Optional[ExpOrStr] = None,
- using: t.Optional[ExpOrStr | t.List[ExpOrStr]] = None,
+ using: t.Optional[ExpOrStr | t.Collection[ExpOrStr]] = None,
append: bool = True,
join_type: t.Optional[str] = None,
join_alias: t.Optional[Identifier | str] = None,
@@ -2943,6 +2987,7 @@ class Select(Subqueryable):
arg="using",
append=append,
copy=copy,
+ into=Identifier,
**opts,
)
@@ -3092,7 +3137,7 @@ class Select(Subqueryable):
Returns:
Select: the modified expression.
"""
- instance = _maybe_copy(self, copy)
+ instance = maybe_copy(self, copy)
on = Tuple(expressions=[maybe_parse(on, copy=copy) for on in ons if on]) if ons else None
instance.set("distinct", Distinct(on=on) if distinct else None)
return instance
@@ -3123,7 +3168,7 @@ class Select(Subqueryable):
Returns:
The new Create expression.
"""
- instance = _maybe_copy(self, copy)
+ instance = maybe_copy(self, copy)
table_expression = maybe_parse(
table,
into=Table,
@@ -3159,7 +3204,7 @@ class Select(Subqueryable):
Returns:
The modified expression.
"""
- inst = _maybe_copy(self, copy)
+ inst = maybe_copy(self, copy)
inst.set("locks", [Lock(update=update)])
return inst
@@ -3181,7 +3226,7 @@ class Select(Subqueryable):
Returns:
The modified expression.
"""
- inst = _maybe_copy(self, copy)
+ inst = maybe_copy(self, copy)
inst.set(
"hint", Hint(expressions=[maybe_parse(h, copy=copy, dialect=dialect) for h in hints])
)
@@ -3376,6 +3421,8 @@ class DataType(Expression):
HSTORE = auto()
IMAGE = auto()
INET = auto()
+ IPADDRESS = auto()
+ IPPREFIX = auto()
INT = auto()
INT128 = auto()
INT256 = auto()
@@ -3987,7 +4034,7 @@ class Case(Func):
arg_types = {"this": False, "ifs": True, "default": False}
def when(self, condition: ExpOrStr, then: ExpOrStr, copy: bool = True, **opts) -> Case:
- instance = _maybe_copy(self, copy)
+ instance = maybe_copy(self, copy)
instance.append(
"ifs",
If(
@@ -3998,7 +4045,7 @@ class Case(Func):
return instance
def else_(self, condition: ExpOrStr, copy: bool = True, **opts) -> Case:
- instance = _maybe_copy(self, copy)
+ instance = maybe_copy(self, copy)
instance.set("default", maybe_parse(condition, copy=copy, **opts))
return instance
@@ -4263,6 +4310,10 @@ class Initcap(Func):
arg_types = {"this": True, "expression": False}
+class IsNan(Func):
+ _sql_names = ["IS_NAN", "ISNAN"]
+
+
class JSONKeyValue(Expression):
arg_types = {"this": True, "expression": True}
@@ -4549,6 +4600,11 @@ class StandardHash(Func):
arg_types = {"this": True, "expression": False}
+class StartsWith(Func):
+ _sql_names = ["STARTS_WITH", "STARTSWITH"]
+ arg_types = {"this": True, "expression": True}
+
+
class StrPosition(Func):
arg_types = {
"this": True,
@@ -4804,7 +4860,7 @@ def maybe_parse(
return sqlglot.parse_one(sql, read=dialect, into=into, **opts)
-def _maybe_copy(instance: E, copy: bool = True) -> E:
+def maybe_copy(instance: E, copy: bool = True) -> E:
return instance.copy() if copy else instance
@@ -4824,7 +4880,7 @@ def _apply_builder(
):
if _is_wrong_expression(expression, into):
expression = into(this=expression)
- instance = _maybe_copy(instance, copy)
+ instance = maybe_copy(instance, copy)
expression = maybe_parse(
sql_or_expression=expression,
prefix=prefix,
@@ -4848,7 +4904,7 @@ def _apply_child_list_builder(
properties=None,
**opts,
):
- instance = _maybe_copy(instance, copy)
+ instance = maybe_copy(instance, copy)
parsed = []
for expression in expressions:
if expression is not None:
@@ -4887,7 +4943,7 @@ def _apply_list_builder(
dialect=None,
**opts,
):
- inst = _maybe_copy(instance, copy)
+ inst = maybe_copy(instance, copy)
expressions = [
maybe_parse(
@@ -4923,7 +4979,7 @@ def _apply_conjunction_builder(
if not expressions:
return instance
- inst = _maybe_copy(instance, copy)
+ inst = maybe_copy(instance, copy)
existing = inst.args.get(arg)
if append and existing is not None:
@@ -5398,7 +5454,7 @@ def to_identifier(name, quoted=None, copy=True):
return None
if isinstance(name, Identifier):
- identifier = _maybe_copy(name, copy)
+ identifier = maybe_copy(name, copy)
elif isinstance(name, str):
identifier = Identifier(
this=name,
@@ -5735,7 +5791,7 @@ def convert(value: t.Any, copy: bool = False) -> Expression:
Expression: the equivalent expression object.
"""
if isinstance(value, Expression):
- return _maybe_copy(value, copy)
+ return maybe_copy(value, copy)
if isinstance(value, str):
return Literal.string(value)
if isinstance(value, bool):