summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/expressions.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r-- sqlglot/expressions.py 258
1 files changed, 169 insertions, 89 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 1691d85..57a2c88 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -1,6 +1,9 @@
+from __future__ import annotations
+
import datetime
import numbers
import re
+import typing as t
from collections import deque
from copy import deepcopy
from enum import auto
@@ -9,12 +12,15 @@ from sqlglot.errors import ParseError
from sqlglot.helper import (
AutoName,
camel_to_snake_case,
- ensure_list,
- list_get,
+ ensure_collection,
+ seq_get,
split_num_words,
subclasses,
)
+if t.TYPE_CHECKING:
+ from sqlglot.dialects.dialect import Dialect
+
class _Expression(type):
def __new__(cls, clsname, bases, attrs):
@@ -35,27 +41,30 @@ class Expression(metaclass=_Expression):
or optional (False).
"""
- key = None
+ key = "Expression"
arg_types = {"this": True}
- __slots__ = ("args", "parent", "arg_key", "type")
+ __slots__ = ("args", "parent", "arg_key", "type", "comment")
def __init__(self, **args):
self.args = args
self.parent = None
self.arg_key = None
self.type = None
+ self.comment = None
for arg_key, value in self.args.items():
self._set_parent(arg_key, value)
- def __eq__(self, other):
+ def __eq__(self, other) -> bool:
return type(self) is type(other) and _norm_args(self) == _norm_args(other)
- def __hash__(self):
+ def __hash__(self) -> int:
return hash(
(
self.key,
- tuple((k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()),
+ tuple(
+ (k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()
+ ),
)
)
@@ -79,6 +88,19 @@ class Expression(metaclass=_Expression):
return field.this
return ""
+ def find_comment(self, key: str) -> str:
+ """
+ Finds the comment that is attached to a specified child node.
+
+ Args:
+ key: the key of the target child node (e.g. "this", "expression", etc).
+
+ Returns:
+ The comment attached to the child node, or the empty string, if it doesn't exist.
+ """
+ field = self.args.get(key)
+ return field.comment if isinstance(field, Expression) else ""
+
@property
def is_string(self):
return isinstance(self, Literal) and self.args["is_string"]
@@ -114,7 +136,10 @@ class Expression(metaclass=_Expression):
return self.alias or self.name
def __deepcopy__(self, memo):
- return self.__class__(**deepcopy(self.args))
+ copy = self.__class__(**deepcopy(self.args))
+ copy.comment = self.comment
+ copy.type = self.type
+ return copy
def copy(self):
new = deepcopy(self)
@@ -249,9 +274,7 @@ class Expression(metaclass=_Expression):
return
for k, v in self.args.items():
- nodes = ensure_list(v)
-
- for node in nodes:
+ for node in ensure_collection(v):
if isinstance(node, Expression):
yield from node.dfs(self, k, prune)
@@ -274,9 +297,7 @@ class Expression(metaclass=_Expression):
if isinstance(item, Expression):
for k, v in item.args.items():
- nodes = ensure_list(v)
-
- for node in nodes:
+ for node in ensure_collection(v):
if isinstance(node, Expression):
queue.append((node, item, k))
@@ -319,7 +340,7 @@ class Expression(metaclass=_Expression):
def __repr__(self):
return self.to_s()
- def sql(self, dialect=None, **opts):
+ def sql(self, dialect: Dialect | str | None = None, **opts) -> str:
"""
Returns SQL string representation of this tree.
@@ -335,7 +356,7 @@ class Expression(metaclass=_Expression):
return Dialect.get_or_raise(dialect)().generate(self, **opts)
- def to_s(self, hide_missing=True, level=0):
+ def to_s(self, hide_missing: bool = True, level: int = 0) -> str:
indent = "" if not level else "\n"
indent += "".join([" "] * level)
left = f"({self.key.upper()} "
@@ -343,11 +364,13 @@ class Expression(metaclass=_Expression):
args = {
k: ", ".join(
v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v)
- for v in ensure_list(vs)
+ for v in ensure_collection(vs)
if v is not None
)
for k, vs in self.args.items()
}
+ args["comment"] = self.comment
+ args["type"] = self.type
args = {k: v for k, v in args.items() if v or not hide_missing}
right = ", ".join(f"{k}: {v}" for k, v in args.items())
@@ -578,17 +601,6 @@ class UDTF(DerivedTable, Unionable):
pass
-class Annotation(Expression):
- arg_types = {
- "this": True,
- "expression": True,
- }
-
- @property
- def alias(self):
- return self.expression.alias_or_name
-
-
class Cache(Expression):
arg_types = {
"with": False,
@@ -623,6 +635,38 @@ class Describe(Expression):
pass
+class Set(Expression):
+ arg_types = {"expressions": True}
+
+
+class SetItem(Expression):
+ arg_types = {
+ "this": True,
+ "kind": False,
+ "collate": False, # MySQL SET NAMES statement
+ }
+
+
+class Show(Expression):
+ arg_types = {
+ "this": True,
+ "target": False,
+ "offset": False,
+ "limit": False,
+ "like": False,
+ "where": False,
+ "db": False,
+ "full": False,
+ "mutex": False,
+ "query": False,
+ "channel": False,
+ "global": False,
+ "log": False,
+ "position": False,
+ "types": False,
+ }
+
+
class UserDefinedFunction(Expression):
arg_types = {"this": True, "expressions": False}
@@ -864,18 +908,20 @@ class Literal(Condition):
def __eq__(self, other):
return (
- isinstance(other, Literal) and self.this == other.this and self.args["is_string"] == other.args["is_string"]
+ isinstance(other, Literal)
+ and self.this == other.this
+ and self.args["is_string"] == other.args["is_string"]
)
def __hash__(self):
return hash((self.key, self.this, self.args["is_string"]))
@classmethod
- def number(cls, number):
+ def number(cls, number) -> Literal:
return cls(this=str(number), is_string=False)
@classmethod
- def string(cls, string):
+ def string(cls, string) -> Literal:
return cls(this=str(string), is_string=True)
@@ -1087,7 +1133,7 @@ class Properties(Expression):
}
@classmethod
- def from_dict(cls, properties_dict):
+ def from_dict(cls, properties_dict) -> Properties:
expressions = []
for key, value in properties_dict.items():
property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty)
@@ -1323,7 +1369,7 @@ class Select(Subqueryable):
**QUERY_MODIFIERS,
}
- def from_(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ def from_(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Set the FROM expression.
@@ -1356,7 +1402,7 @@ class Select(Subqueryable):
**opts,
)
- def group_by(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ def group_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Set the GROUP BY expression.
@@ -1392,7 +1438,7 @@ class Select(Subqueryable):
**opts,
)
- def order_by(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ def order_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Set the ORDER BY expression.
@@ -1425,7 +1471,7 @@ class Select(Subqueryable):
**opts,
)
- def sort_by(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ def sort_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Set the SORT BY expression.
@@ -1458,7 +1504,7 @@ class Select(Subqueryable):
**opts,
)
- def cluster_by(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ def cluster_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Set the CLUSTER BY expression.
@@ -1491,7 +1537,7 @@ class Select(Subqueryable):
**opts,
)
- def limit(self, expression, dialect=None, copy=True, **opts):
+ def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
"""
Set the LIMIT expression.
@@ -1522,7 +1568,7 @@ class Select(Subqueryable):
**opts,
)
- def offset(self, expression, dialect=None, copy=True, **opts):
+ def offset(self, expression, dialect=None, copy=True, **opts) -> Select:
"""
Set the OFFSET expression.
@@ -1553,7 +1599,7 @@ class Select(Subqueryable):
**opts,
)
- def select(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ def select(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Append to or set the SELECT expressions.
@@ -1583,7 +1629,7 @@ class Select(Subqueryable):
**opts,
)
- def lateral(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ def lateral(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Append to or set the LATERAL expressions.
@@ -1626,7 +1672,7 @@ class Select(Subqueryable):
dialect=None,
copy=True,
**opts,
- ):
+ ) -> Select:
"""
Append to or set the JOIN expressions.
@@ -1672,7 +1718,7 @@ class Select(Subqueryable):
join.this.replace(join.this.subquery())
if join_type:
- natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args)
+ natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) # type: ignore
if natural:
join.set("natural", True)
if side:
@@ -1681,12 +1727,12 @@ class Select(Subqueryable):
join.set("kind", kind.text)
if on:
- on = and_(*ensure_list(on), dialect=dialect, **opts)
+ on = and_(*ensure_collection(on), dialect=dialect, **opts)
join.set("on", on)
if using:
join = _apply_list_builder(
- *ensure_list(using),
+ *ensure_collection(using),
instance=join,
arg="using",
append=append,
@@ -1705,7 +1751,7 @@ class Select(Subqueryable):
**opts,
)
- def where(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ def where(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Append to or set the WHERE expressions.
@@ -1737,7 +1783,7 @@ class Select(Subqueryable):
**opts,
)
- def having(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ def having(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Append to or set the HAVING expressions.
@@ -1769,7 +1815,7 @@ class Select(Subqueryable):
**opts,
)
- def distinct(self, distinct=True, copy=True):
+ def distinct(self, distinct=True, copy=True) -> Select:
"""
Set the OFFSET expression.
@@ -1788,7 +1834,7 @@ class Select(Subqueryable):
instance.set("distinct", Distinct() if distinct else None)
return instance
- def ctas(self, table, properties=None, dialect=None, copy=True, **opts):
+ def ctas(self, table, properties=None, dialect=None, copy=True, **opts) -> Create:
"""
Convert this expression to a CREATE TABLE AS statement.
@@ -1826,11 +1872,11 @@ class Select(Subqueryable):
)
@property
- def named_selects(self):
+ def named_selects(self) -> t.List[str]:
return [e.alias_or_name for e in self.expressions if e.alias_or_name]
@property
- def selects(self):
+ def selects(self) -> t.List[Expression]:
return self.expressions
@@ -1910,12 +1956,16 @@ class Parameter(Expression):
pass
+class SessionParameter(Expression):
+ arg_types = {"this": True, "kind": False}
+
+
class Placeholder(Expression):
arg_types = {"this": False}
class Null(Condition):
- arg_types = {}
+ arg_types: t.Dict[str, t.Any] = {}
class Boolean(Condition):
@@ -1936,6 +1986,7 @@ class DataType(Expression):
NVARCHAR = auto()
TEXT = auto()
BINARY = auto()
+ VARBINARY = auto()
INT = auto()
TINYINT = auto()
SMALLINT = auto()
@@ -1975,7 +2026,7 @@ class DataType(Expression):
UNKNOWN = auto() # Sentinel value, useful for type annotation
@classmethod
- def build(cls, dtype, **kwargs):
+ def build(cls, dtype, **kwargs) -> DataType:
return DataType(
this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()],
**kwargs,
@@ -2077,6 +2128,18 @@ class EQ(Binary, Predicate):
pass
+class NullSafeEQ(Binary, Predicate):
+ pass
+
+
+class NullSafeNEQ(Binary, Predicate):
+ pass
+
+
+class Distance(Binary):
+ pass
+
+
class Escape(Binary):
pass
@@ -2101,15 +2164,11 @@ class Is(Binary, Predicate):
pass
-class Like(Binary, Predicate):
- pass
-
-
-class SimilarTo(Binary, Predicate):
- pass
+class Kwarg(Binary):
+ """Kwarg in special functions like func(kwarg => y)."""
-class Distance(Binary):
+class Like(Binary, Predicate):
pass
@@ -2133,6 +2192,10 @@ class NEQ(Binary, Predicate):
pass
+class SimilarTo(Binary, Predicate):
+ pass
+
+
class Sub(Binary):
pass
@@ -2189,7 +2252,13 @@ class Distinct(Expression):
class In(Predicate):
- arg_types = {"this": True, "expressions": False, "query": False, "unnest": False, "field": False}
+ arg_types = {
+ "this": True,
+ "expressions": False,
+ "query": False,
+ "unnest": False,
+ "field": False,
+ }
class TimeUnit(Expression):
@@ -2255,7 +2324,9 @@ class Func(Condition):
@classmethod
def sql_names(cls):
if cls is Func:
- raise NotImplementedError("SQL name is only supported by concrete function implementations")
+ raise NotImplementedError(
+ "SQL name is only supported by concrete function implementations"
+ )
if not hasattr(cls, "_sql_names"):
cls._sql_names = [camel_to_snake_case(cls.__name__)]
return cls._sql_names
@@ -2408,8 +2479,8 @@ class DateDiff(Func, TimeUnit):
arg_types = {"this": True, "expression": True, "unit": False}
-class DateTrunc(Func, TimeUnit):
- arg_types = {"this": True, "unit": True, "zone": False}
+class DateTrunc(Func):
+ arg_types = {"this": True, "expression": True, "zone": False}
class DatetimeAdd(Func, TimeUnit):
@@ -2791,6 +2862,10 @@ class Year(Func):
pass
+class Use(Expression):
+ pass
+
+
def _norm_args(expression):
args = {}
@@ -2822,7 +2897,7 @@ def maybe_parse(
dialect=None,
prefix=None,
**opts,
-):
+) -> t.Optional[Expression]:
"""Gracefully handle a possible string or expression.
Example:
@@ -3073,7 +3148,7 @@ def except_(left, right, distinct=True, dialect=None, **opts):
return Except(this=left, expression=right, distinct=distinct)
-def select(*expressions, dialect=None, **opts):
+def select(*expressions, dialect=None, **opts) -> Select:
"""
Initializes a syntax tree from one or multiple SELECT expressions.
@@ -3095,7 +3170,7 @@ def select(*expressions, dialect=None, **opts):
return Select().select(*expressions, dialect=dialect, **opts)
-def from_(*expressions, dialect=None, **opts):
+def from_(*expressions, dialect=None, **opts) -> Select:
"""
Initializes a syntax tree from a FROM expression.
@@ -3117,7 +3192,7 @@ def from_(*expressions, dialect=None, **opts):
return Select().from_(*expressions, dialect=dialect, **opts)
-def update(table, properties, where=None, from_=None, dialect=None, **opts):
+def update(table, properties, where=None, from_=None, dialect=None, **opts) -> Update:
"""
Creates an update statement.
@@ -3139,7 +3214,10 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts):
update = Update(this=maybe_parse(table, into=Table, dialect=dialect))
update.set(
"expressions",
- [EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v)) for k, v in properties.items()],
+ [
+ EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v))
+ for k, v in properties.items()
+ ],
)
if from_:
update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts))
@@ -3150,7 +3228,7 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts):
return update
-def delete(table, where=None, dialect=None, **opts):
+def delete(table, where=None, dialect=None, **opts) -> Delete:
"""
Builds a delete statement.
@@ -3174,7 +3252,7 @@ def delete(table, where=None, dialect=None, **opts):
)
-def condition(expression, dialect=None, **opts):
+def condition(expression, dialect=None, **opts) -> Condition:
"""
Initialize a logical condition expression.
@@ -3199,7 +3277,7 @@ def condition(expression, dialect=None, **opts):
Returns:
Condition: the expression
"""
- return maybe_parse(
+ return maybe_parse( # type: ignore
expression,
into=Condition,
dialect=dialect,
@@ -3207,7 +3285,7 @@ def condition(expression, dialect=None, **opts):
)
-def and_(*expressions, dialect=None, **opts):
+def and_(*expressions, dialect=None, **opts) -> And:
"""
Combine multiple conditions with an AND logical operator.
@@ -3227,7 +3305,7 @@ def and_(*expressions, dialect=None, **opts):
return _combine(expressions, And, dialect, **opts)
-def or_(*expressions, dialect=None, **opts):
+def or_(*expressions, dialect=None, **opts) -> Or:
"""
Combine multiple conditions with an OR logical operator.
@@ -3247,7 +3325,7 @@ def or_(*expressions, dialect=None, **opts):
return _combine(expressions, Or, dialect, **opts)
-def not_(expression, dialect=None, **opts):
+def not_(expression, dialect=None, **opts) -> Not:
"""
Wrap a condition with a NOT operator.
@@ -3272,14 +3350,14 @@ def not_(expression, dialect=None, **opts):
return Not(this=_wrap_operator(this))
-def paren(expression):
+def paren(expression) -> Paren:
return Paren(this=expression)
SAFE_IDENTIFIER_RE = re.compile(r"^[a-zA-Z][\w]*$")
-def to_identifier(alias, quoted=None):
+def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
if alias is None:
return None
if isinstance(alias, Identifier):
@@ -3293,16 +3371,16 @@ def to_identifier(alias, quoted=None):
return identifier
-def to_table(sql_path: str, **kwargs) -> Table:
+def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]:
"""
Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional.
-
If a table is passed in then that table is returned.
Args:
- sql_path(str|Table): `[catalog].[schema].[table]` string
+ sql_path: a `[catalog].[schema].[table]` string.
+
Returns:
- Table: A table expression
+ A table expression.
"""
if sql_path is None or isinstance(sql_path, Table):
return sql_path
@@ -3393,7 +3471,7 @@ def subquery(expression, alias=None, dialect=None, **opts):
return Select().from_(expression, dialect=dialect, **opts)
-def column(col, table=None, quoted=None):
+def column(col, table=None, quoted=None) -> Column:
"""
Build a Column.
Args:
@@ -3408,7 +3486,7 @@ def column(col, table=None, quoted=None):
)
-def table_(table, db=None, catalog=None, quoted=None, alias=None):
+def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table:
"""Build a Table.
Args:
@@ -3427,7 +3505,7 @@ def table_(table, db=None, catalog=None, quoted=None, alias=None):
)
-def values(values, alias=None):
+def values(values, alias=None) -> Values:
"""Build VALUES statement.
Example:
@@ -3449,7 +3527,7 @@ def values(values, alias=None):
)
-def convert(value):
+def convert(value) -> Expression:
"""Convert a python value into an expression object.
Raises an error if a conversion is not possible.
@@ -3500,15 +3578,14 @@ def replace_children(expression, fun):
for cn in child_nodes:
if isinstance(cn, Expression):
- cns = ensure_list(fun(cn))
- for child_node in cns:
+ for child_node in ensure_collection(fun(cn)):
new_child_nodes.append(child_node)
child_node.parent = expression
child_node.arg_key = k
else:
new_child_nodes.append(cn)
- expression.args[k] = new_child_nodes if is_list_arg else list_get(new_child_nodes, 0)
+ expression.args[k] = new_child_nodes if is_list_arg else seq_get(new_child_nodes, 0)
def column_table_names(expression):
@@ -3529,7 +3606,7 @@ def column_table_names(expression):
return list(dict.fromkeys(column.table for column in expression.find_all(Column)))
-def table_name(table):
+def table_name(table) -> str:
"""Get the full name of a table as a string.
Args:
@@ -3546,6 +3623,9 @@ def table_name(table):
table = maybe_parse(table, into=Table)
+ if not table:
+ raise ValueError(f"Cannot parse {table}")
+
return ".".join(
part
for part in (