summaryrefslogtreecommitdiffstats
path: root/sqlglot/expressions.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r--sqlglot/expressions.py169
1 files changed, 113 insertions, 56 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 7acc63d..b983bf9 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -47,10 +47,7 @@ class Expression(metaclass=_Expression):
return hash(
(
self.key,
- tuple(
- (k, tuple(v) if isinstance(v, list) else v)
- for k, v in _norm_args(self).items()
- ),
+ tuple((k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()),
)
)
@@ -116,9 +113,22 @@ class Expression(metaclass=_Expression):
item.parent = parent
return new
+ def append(self, arg_key, value):
+ """
+ Appends value to arg_key if it's a list or sets it as a new list.
+
+ Args:
+ arg_key (str): name of the list expression arg
+ value (Any): value to append to the list
+ """
+ if not isinstance(self.args.get(arg_key), list):
+ self.args[arg_key] = []
+ self.args[arg_key].append(value)
+ self._set_parent(arg_key, value)
+
def set(self, arg_key, value):
"""
- Sets `arg` to `value`.
+ Sets `arg_key` to `value`.
Args:
arg_key (str): name of the expression arg
@@ -267,6 +277,14 @@ class Expression(metaclass=_Expression):
expression = expression.this
return expression
+ def unalias(self):
+ """
+ Returns the inner expression if this is an Alias.
+ """
+ if isinstance(self, Alias):
+ return self.this
+ return self
+
def unnest_operands(self):
"""
Returns unnested operands as a tuple.
@@ -279,9 +297,7 @@ class Expression(metaclass=_Expression):
A AND B AND C -> [A, B, C]
"""
- for node, _, _ in self.dfs(
- prune=lambda n, p, *_: p and not isinstance(n, self.__class__)
- ):
+ for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not isinstance(n, self.__class__)):
if not isinstance(node, self.__class__):
yield node.unnest() if unnest else node
@@ -314,9 +330,7 @@ class Expression(metaclass=_Expression):
args = {
k: ", ".join(
- v.to_s(hide_missing=hide_missing, level=level + 1)
- if hasattr(v, "to_s")
- else str(v)
+ v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v)
for v in ensure_list(vs)
if v is not None
)
@@ -354,9 +368,7 @@ class Expression(metaclass=_Expression):
new_node.parent = node.parent
return new_node
- replace_children(
- new_node, lambda child: child.transform(fun, *args, copy=False, **kwargs)
- )
+ replace_children(new_node, lambda child: child.transform(fun, *args, copy=False, **kwargs))
return new_node
def replace(self, expression):
@@ -546,6 +558,10 @@ class BitString(Condition):
pass
+class HexString(Condition):
+ pass
+
+
class Column(Condition):
arg_types = {"this": True, "table": False}
@@ -566,35 +582,44 @@ class ColumnConstraint(Expression):
arg_types = {"this": False, "kind": True}
-class AutoIncrementColumnConstraint(Expression):
+class ColumnConstraintKind(Expression):
pass
-class CheckColumnConstraint(Expression):
+class AutoIncrementColumnConstraint(ColumnConstraintKind):
pass
-class CollateColumnConstraint(Expression):
+class CheckColumnConstraint(ColumnConstraintKind):
pass
-class CommentColumnConstraint(Expression):
+class CollateColumnConstraint(ColumnConstraintKind):
pass
-class DefaultColumnConstraint(Expression):
+class CommentColumnConstraint(ColumnConstraintKind):
pass
-class NotNullColumnConstraint(Expression):
+class DefaultColumnConstraint(ColumnConstraintKind):
pass
-class PrimaryKeyColumnConstraint(Expression):
+class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
+ # this: True -> ALWAYS, this: False -> BY DEFAULT
+ arg_types = {"this": True, "expression": False}
+
+
+class NotNullColumnConstraint(ColumnConstraintKind):
pass
-class UniqueColumnConstraint(Expression):
+class PrimaryKeyColumnConstraint(ColumnConstraintKind):
+ pass
+
+
+class UniqueColumnConstraint(ColumnConstraintKind):
pass
@@ -651,9 +676,7 @@ class Identifier(Expression):
return bool(self.args.get("quoted"))
def __eq__(self, other):
- return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg(
- other.this
- )
+ return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg(other.this)
def __hash__(self):
return hash((self.key, self.this.lower()))
@@ -709,9 +732,7 @@ class Literal(Condition):
def __eq__(self, other):
return (
- isinstance(other, Literal)
- and self.this == other.this
- and self.args["is_string"] == other.args["is_string"]
+ isinstance(other, Literal) and self.this == other.this and self.args["is_string"] == other.args["is_string"]
)
def __hash__(self):
@@ -733,6 +754,7 @@ class Join(Expression):
"side": False,
"kind": False,
"using": False,
+ "natural": False,
}
@property
@@ -743,6 +765,10 @@ class Join(Expression):
def side(self):
return self.text("side").upper()
+ @property
+ def alias_or_name(self):
+ return self.this.alias_or_name
+
def on(self, *expressions, append=True, dialect=None, copy=True, **opts):
"""
Append to or set the ON expressions.
@@ -873,10 +899,6 @@ class Reference(Expression):
arg_types = {"this": True, "expressions": True}
-class Table(Expression):
- arg_types = {"this": True, "db": False, "catalog": False}
-
-
class Tuple(Expression):
arg_types = {"expressions": False}
@@ -986,6 +1008,16 @@ QUERY_MODIFIERS = {
}
+class Table(Expression):
+ arg_types = {
+ "this": True,
+ "db": False,
+ "catalog": False,
+ "laterals": False,
+ "joins": False,
+ }
+
+
class Union(Subqueryable, Expression):
arg_types = {
"with": False,
@@ -1396,7 +1428,9 @@ class Select(Subqueryable, Expression):
join.this.replace(join.this.subquery())
if join_type:
- side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args)
+ natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args)
+ if natural:
+ join.set("natural", True)
if side:
join.set("side", side.text)
if kind:
@@ -1529,10 +1563,7 @@ class Select(Subqueryable, Expression):
properties_expression = None
if properties:
properties_str = " ".join(
- [
- f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}"
- for k, v in properties.items()
- ]
+ [f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" for k, v in properties.items()]
)
properties_expression = maybe_parse(
properties_str,
@@ -1654,6 +1685,7 @@ class DataType(Expression):
DECIMAL = auto()
BOOLEAN = auto()
JSON = auto()
+ INTERVAL = auto()
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
DATE = auto()
@@ -1662,15 +1694,19 @@ class DataType(Expression):
MAP = auto()
UUID = auto()
GEOGRAPHY = auto()
+ GEOMETRY = auto()
STRUCT = auto()
NULLABLE = auto()
+ HLLSKETCH = auto()
+ SUPER = auto()
+ SERIAL = auto()
+ SMALLSERIAL = auto()
+ BIGSERIAL = auto()
@classmethod
def build(cls, dtype, **kwargs):
return DataType(
- this=dtype
- if isinstance(dtype, DataType.Type)
- else DataType.Type[dtype.upper()],
+ this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()],
**kwargs,
)
@@ -1798,6 +1834,14 @@ class Like(Binary, Predicate):
pass
+class SimilarTo(Binary, Predicate):
+ pass
+
+
+class Distance(Binary):
+ pass
+
+
class LT(Binary, Predicate):
pass
@@ -1899,6 +1943,10 @@ class IgnoreNulls(Expression):
pass
+class RespectNulls(Expression):
+ pass
+
+
# Functions
class Func(Condition):
"""
@@ -1924,9 +1972,7 @@ class Func(Condition):
all_arg_keys = list(cls.arg_types)
# If this function supports variable length argument treat the last argument as such.
- non_var_len_arg_keys = (
- all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys
- )
+ non_var_len_arg_keys = all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys
args_dict = {}
arg_idx = 0
@@ -1944,9 +1990,7 @@ class Func(Condition):
@classmethod
def sql_names(cls):
if cls is Func:
- raise NotImplementedError(
- "SQL name is only supported by concrete function implementations"
- )
+ raise NotImplementedError("SQL name is only supported by concrete function implementations")
if not hasattr(cls, "_sql_names"):
cls._sql_names = [camel_to_snake_case(cls.__name__)]
return cls._sql_names
@@ -2178,6 +2222,10 @@ class Greatest(Func):
is_var_len_args = True
+class GroupConcat(Func):
+ arg_types = {"this": True, "separator": False}
+
+
class If(Func):
arg_types = {"this": True, "true": True, "false": False}
@@ -2274,6 +2322,10 @@ class Quantile(AggFunc):
arg_types = {"this": True, "quantile": True}
+class ApproxQuantile(Quantile):
+ pass
+
+
class Reduce(Func):
arg_types = {"this": True, "initial": True, "merge": True, "finish": True}
@@ -2306,8 +2358,10 @@ class Split(Func):
arg_types = {"this": True, "expression": True}
+# Start may be omitted in the case of postgres
+# https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6
class Substring(Func):
- arg_types = {"this": True, "start": True, "length": False}
+ arg_types = {"this": True, "start": False, "length": False}
class StrPosition(Func):
@@ -2379,6 +2433,15 @@ class TimeStrToUnix(Func):
pass
+class Trim(Func):
+ arg_types = {
+ "this": True,
+ "position": False,
+ "expression": False,
+ "collation": False,
+ }
+
+
class TsOrDsAdd(Func, TimeUnit):
arg_types = {"this": True, "expression": True, "unit": False}
@@ -2455,9 +2518,7 @@ def _all_functions():
obj
for _, obj in inspect.getmembers(
sys.modules[__name__],
- lambda obj: inspect.isclass(obj)
- and issubclass(obj, Func)
- and obj not in (AggFunc, Anonymous, Func),
+ lambda obj: inspect.isclass(obj) and issubclass(obj, Func) and obj not in (AggFunc, Anonymous, Func),
)
]
@@ -2633,9 +2694,7 @@ def _apply_conjunction_builder(
def _combine(expressions, operator, dialect=None, **opts):
- expressions = [
- condition(expression, dialect=dialect, **opts) for expression in expressions
- ]
+ expressions = [condition(expression, dialect=dialect, **opts) for expression in expressions]
this = expressions[0]
if expressions[1:]:
this = _wrap_operator(this)
@@ -2809,9 +2868,7 @@ def to_identifier(alias, quoted=None):
quoted = not re.match(SAFE_IDENTIFIER_RE, alias)
identifier = Identifier(this=alias, quoted=quoted)
else:
- raise ValueError(
- f"Alias needs to be a string or an Identifier, got: {alias.__class__}"
- )
+ raise ValueError(f"Alias needs to be a string or an Identifier, got: {alias.__class__}")
return identifier