summaryrefslogtreecommitdiffstats
path: root/sqlglot/expressions.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r--sqlglot/expressions.py410
1 files changed, 340 insertions, 70 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index be99fe2..f9751ca 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -1,5 +1,12 @@
"""
-.. include:: ../pdoc/docs/expressions.md
+## Expressions
+
+Every AST node in SQLGlot is represented by a subclass of `Expression`.
+
+This module contains the implementation of all supported `Expression` types. Additionally,
+it exposes a number of helper functions, which are mainly used to programmatically build
+SQL expressions, such as `sqlglot.expressions.select`.
+----
"""
from __future__ import annotations
@@ -27,35 +34,66 @@ from sqlglot.tokens import Token
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import Dialect
+ IntoType = t.Union[
+ str,
+ t.Type[Expression],
+ t.Collection[t.Union[str, t.Type[Expression]]],
+ ]
+
class _Expression(type):
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
+
+ # When an Expression class is created, its key is automatically set to be
+ # the lowercase version of the class' name.
klass.key = clsname.lower()
+
+ # This is so that docstrings are not inherited in pdoc
+ klass.__doc__ = klass.__doc__ or ""
+
return klass
class Expression(metaclass=_Expression):
"""
- The base class for all expressions in a syntax tree.
+ The base class for all expressions in a syntax tree. Each Expression encapsulates any necessary
+ context, such as its child expressions, their names (arg keys), and whether a given child expression
+ is optional or not.
Attributes:
- arg_types (dict): determines arguments supported by this expression.
- The key in a dictionary defines a unique key of an argument using
- which the argument's value can be retrieved. The value is a boolean
- flag which indicates whether the argument's value is required (True)
- or optional (False).
+ key: a unique key for each class in the Expression hierarchy. This is useful for hashing
+ and representing expressions as strings.
+ arg_types: determines what arguments (child nodes) are supported by an expression. It
+ maps arg keys to booleans that indicate whether the corresponding args are optional.
+
+ Example:
+ >>> class Foo(Expression):
+ ... arg_types = {"this": True, "expression": False}
+
+ The above definition informs us that Foo is an Expression that requires an argument called
+ "this" and may also optionally receive an argument called "expression".
+
+ Args:
+ args: a mapping used for retrieving the arguments of an expression, given their arg keys.
+ parent: a reference to the parent expression (or None, in case of root expressions).
+ arg_key: the arg key an expression is associated with, i.e. the name its parent expression
+ uses to refer to it.
+ comments: a list of comments that are associated with a given expression. This is used in
+ order to preserve comments when transpiling SQL code.
+ _type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the
+ optimizer, in order to enable some transformations that require type information.
"""
- key = "Expression"
+ key = "expression"
arg_types = {"this": True}
__slots__ = ("args", "parent", "arg_key", "comments", "_type")
- def __init__(self, **args):
- self.args = args
- self.parent = None
- self.arg_key = None
- self.comments = None
+ def __init__(self, **args: t.Any):
+ self.args: t.Dict[str, t.Any] = args
+ self.parent: t.Optional[Expression] = None
+ self.arg_key: t.Optional[str] = None
+ self.comments: t.Optional[t.List[str]] = None
self._type: t.Optional[DataType] = None
for arg_key, value in self.args.items():
@@ -76,17 +114,30 @@ class Expression(metaclass=_Expression):
@property
def this(self):
+ """
+ Retrieves the argument with key "this".
+ """
return self.args.get("this")
@property
def expression(self):
+ """
+ Retrieves the argument with key "expression".
+ """
return self.args.get("expression")
@property
def expressions(self):
+ """
+ Retrieves the argument with key "expressions".
+ """
return self.args.get("expressions") or []
def text(self, key):
+ """
+ Returns a textual representation of the argument corresponding to "key". This can only be used
+ for args that are strings or leaf Expression instances, such as identifiers and literals.
+ """
field = self.args.get(key)
if isinstance(field, str):
return field
@@ -96,14 +147,23 @@ class Expression(metaclass=_Expression):
@property
def is_string(self):
+ """
+ Checks whether a Literal expression is a string.
+ """
return isinstance(self, Literal) and self.args["is_string"]
@property
def is_number(self):
+ """
+ Checks whether a Literal expression is a number.
+ """
return isinstance(self, Literal) and not self.args["is_string"]
@property
def is_int(self):
+ """
+ Checks whether a Literal expression is an integer.
+ """
if self.is_number:
try:
int(self.name)
@@ -114,6 +174,9 @@ class Expression(metaclass=_Expression):
@property
def alias(self):
+ """
+ Returns the alias of the expression, or an empty string if it's not aliased.
+ """
if isinstance(self.args.get("alias"), TableAlias):
return self.args["alias"].name
return self.text("alias")
@@ -129,6 +192,24 @@ class Expression(metaclass=_Expression):
return self.alias or self.name
@property
+ def output_name(self):
+ """
+ Name of the output column if this expression is a selection.
+
+ If the Expression has no output name, an empty string is returned.
+
+ Example:
+ >>> from sqlglot import parse_one
+ >>> parse_one("SELECT a").expressions[0].output_name
+ 'a'
+ >>> parse_one("SELECT b AS c").expressions[0].output_name
+ 'c'
+ >>> parse_one("SELECT 1 + 2").expressions[0].output_name
+ ''
+ """
+ return ""
+
+ @property
def type(self) -> t.Optional[DataType]:
return self._type
@@ -145,6 +226,9 @@ class Expression(metaclass=_Expression):
return copy
def copy(self):
+ """
+ Returns a deep copy of the expression.
+ """
new = deepcopy(self)
for item, parent, _ in new.bfs():
if isinstance(item, Expression) and parent:
@@ -169,7 +253,7 @@ class Expression(metaclass=_Expression):
Sets `arg_key` to `value`.
Args:
- arg_key (str): name of the expression arg
+ arg_key (str): name of the expression arg.
value: value to set the arg to.
"""
self.args[arg_key] = value
@@ -203,8 +287,7 @@ class Expression(metaclass=_Expression):
expression_types (type): the expression type(s) to match.
Returns:
- the node which matches the criteria or None if no node matching
- the criteria was found.
+ The node which matches the criteria or None if no such node was found.
"""
return next(self.find_all(*expression_types, bfs=bfs), None)
@@ -217,7 +300,7 @@ class Expression(metaclass=_Expression):
expression_types (type): the expression type(s) to match.
Returns:
- the generator object.
+ The generator object.
"""
for expression, _, _ in self.walk(bfs=bfs):
if isinstance(expression, expression_types):
@@ -231,7 +314,7 @@ class Expression(metaclass=_Expression):
expression_types (type): the expression type(s) to match.
Returns:
- the parent node
+ The parent node.
"""
ancestor = self.parent
while ancestor and not isinstance(ancestor, expression_types):
@@ -269,7 +352,7 @@ class Expression(metaclass=_Expression):
the DFS (Depth-first) order.
Returns:
- the generator object.
+ The generator object.
"""
parent = parent or self.parent
yield self, parent, key
@@ -287,7 +370,7 @@ class Expression(metaclass=_Expression):
the BFS (Breadth-first) order.
Returns:
- the generator object.
+ The generator object.
"""
queue = deque([(self, self.parent, None)])
@@ -341,32 +424,33 @@ class Expression(metaclass=_Expression):
return self.sql()
def __repr__(self):
- return self.to_s()
+ return self._to_s()
def sql(self, dialect: Dialect | str | None = None, **opts) -> str:
"""
Returns SQL string representation of this tree.
- Args
- dialect (str): the dialect of the output SQL string
- (eg. "spark", "hive", "presto", "mysql").
- opts (dict): other :class:`~sqlglot.generator.Generator` options.
+ Args:
+ dialect: the dialect of the output SQL string (eg. "spark", "hive", "presto", "mysql").
+ opts: other `sqlglot.generator.Generator` options.
- Returns
- the SQL string.
+ Returns:
+ The SQL string.
"""
from sqlglot.dialects import Dialect
return Dialect.get_or_raise(dialect)().generate(self, **opts)
- def to_s(self, hide_missing: bool = True, level: int = 0) -> str:
+ def _to_s(self, hide_missing: bool = True, level: int = 0) -> str:
indent = "" if not level else "\n"
indent += "".join([" "] * level)
left = f"({self.key.upper()} "
args: t.Dict[str, t.Any] = {
k: ", ".join(
- v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v)
+ v._to_s(hide_missing=hide_missing, level=level + 1)
+ if hasattr(v, "_to_s")
+ else str(v)
for v in ensure_collection(vs)
if v is not None
)
@@ -394,7 +478,7 @@ class Expression(metaclass=_Expression):
modified in place.
Returns:
- the transformed tree.
+ The transformed tree.
"""
node = self.copy() if copy else self
new_node = fun(node, *args, **kwargs)
@@ -423,8 +507,8 @@ class Expression(metaclass=_Expression):
Args:
expression (Expression|None): new node
- Returns :
- the new expression or expressions
+ Returns:
+ The new expression or expressions.
"""
if not self.parent:
return expression
@@ -458,6 +542,40 @@ class Expression(metaclass=_Expression):
assert isinstance(self, type_)
return self
+ def error_messages(self, args: t.Optional[t.Sequence] = None) -> t.List[str]:
+ """
+ Checks if this expression is valid (e.g. all mandatory args are set).
+
+ Args:
+ args: a sequence of values that were used to instantiate a Func expression. This is used
+ to check that the provided arguments don't exceed the function argument limit.
+
+ Returns:
+ A list of error messages for all possible errors that were found.
+ """
+ errors: t.List[str] = []
+
+ for k in self.args:
+ if k not in self.arg_types:
+ errors.append(f"Unexpected keyword: '{k}' for {self.__class__}")
+ for k, mandatory in self.arg_types.items():
+ v = self.args.get(k)
+ if mandatory and (v is None or (isinstance(v, list) and not v)):
+ errors.append(f"Required keyword: '{k}' missing for {self.__class__}")
+
+ if (
+ args
+ and isinstance(self, Func)
+ and len(args) > len(self.arg_types)
+ and not self.is_var_len_args
+ ):
+ errors.append(
+ f"The number of provided arguments ({len(args)}) is greater than "
+ f"the maximum number of supported arguments ({len(self.arg_types)})"
+ )
+
+ return errors
+
def dump(self):
"""
Dump this Expression to a JSON-serializable dict.
@@ -552,7 +670,7 @@ class DerivedTable(Expression):
@property
def named_selects(self):
- return [select.alias_or_name for select in self.selects]
+ return [select.output_name for select in self.selects]
class Unionable(Expression):
@@ -654,6 +772,7 @@ class Create(Expression):
"no_primary_index": False,
"indexes": False,
"no_schema_binding": False,
+ "begin": False,
}
@@ -696,7 +815,7 @@ class Show(Expression):
class UserDefinedFunction(Expression):
- arg_types = {"this": True, "expressions": False}
+ arg_types = {"this": True, "expressions": False, "wrapped": False}
class UserDefinedFunctionKwarg(Expression):
@@ -750,6 +869,10 @@ class Column(Condition):
def table(self):
return self.text("table")
+ @property
+ def output_name(self):
+ return self.name
+
class ColumnDef(Expression):
arg_types = {
@@ -865,6 +988,10 @@ class ForeignKey(Expression):
}
+class PrimaryKey(Expression):
+ arg_types = {"expressions": True, "options": False}
+
+
class Unique(Expression):
arg_types = {"expressions": True}
@@ -904,6 +1031,10 @@ class Identifier(Expression):
def __hash__(self):
return hash((self.key, self.this.lower()))
+ @property
+ def output_name(self):
+ return self.name
+
class Index(Expression):
arg_types = {
@@ -996,6 +1127,10 @@ class Literal(Condition):
def string(cls, string) -> Literal:
return cls(this=str(string), is_string=True)
+ @property
+ def output_name(self):
+ return self.name
+
class Join(Expression):
arg_types = {
@@ -1186,7 +1321,7 @@ class SchemaCommentProperty(Property):
class ReturnsProperty(Property):
- arg_types = {"this": True, "is_table": False}
+ arg_types = {"this": True, "is_table": False, "table": False}
class LanguageProperty(Property):
@@ -1262,8 +1397,13 @@ class Qualify(Expression):
pass
+# https://www.ibm.com/docs/en/ias?topic=procedures-return-statement-in-sql
+class Return(Expression):
+ pass
+
+
class Reference(Expression):
- arg_types = {"this": True, "expressions": True}
+ arg_types = {"this": True, "expressions": False, "options": False}
class Tuple(Expression):
@@ -1397,6 +1537,16 @@ class Table(Expression):
"joins": False,
"pivots": False,
"hints": False,
+ "system_time": False,
+ }
+
+
+# See the TSQL "Querying data in a system-versioned temporal table" page
+class SystemTime(Expression):
+ arg_types = {
+ "this": False,
+ "expression": False,
+ "kind": True,
}
@@ -2027,7 +2177,7 @@ class Select(Subqueryable):
@property
def named_selects(self) -> t.List[str]:
- return [e.alias_or_name for e in self.expressions if e.alias_or_name]
+ return [e.output_name for e in self.expressions if e.alias_or_name]
@property
def selects(self) -> t.List[Expression]:
@@ -2051,6 +2201,10 @@ class Subquery(DerivedTable, Unionable):
expression = expression.this
return expression
+ @property
+ def output_name(self):
+ return self.alias
+
class TableSample(Expression):
arg_types = {
@@ -2066,6 +2220,16 @@ class TableSample(Expression):
}
+class Tag(Expression):
+ """Tags are used for generating arbitrary sql like SELECT <span>x</span>."""
+
+ arg_types = {
+ "this": False,
+ "prefix": False,
+ "postfix": False,
+ }
+
+
class Pivot(Expression):
arg_types = {
"this": False,
@@ -2106,6 +2270,10 @@ class Star(Expression):
def name(self):
return "*"
+ @property
+ def output_name(self):
+ return self.name
+
class Parameter(Expression):
pass
@@ -2143,6 +2311,8 @@ class DataType(Expression):
TEXT = auto()
MEDIUMTEXT = auto()
LONGTEXT = auto()
+ MEDIUMBLOB = auto()
+ LONGBLOB = auto()
BINARY = auto()
VARBINARY = auto()
INT = auto()
@@ -2282,11 +2452,11 @@ class Rollback(Expression):
class AlterTable(Expression):
- arg_types = {
- "this": True,
- "actions": True,
- "exists": False,
- }
+ arg_types = {"this": True, "actions": True, "exists": False}
+
+
+class AddConstraint(Expression):
+ arg_types = {"this": False, "expression": False, "enforced": False}
# Binary expressions like (ADD a b)
@@ -2456,6 +2626,10 @@ class Neg(Unary):
class Alias(Expression):
arg_types = {"this": True, "alias": False}
+ @property
+ def output_name(self):
+ return self.alias
+
class Aliases(Expression):
arg_types = {"this": True, "expressions": True}
@@ -2523,16 +2697,13 @@ class Func(Condition):
"""
The base class for all function expressions.
- Attributes
- is_var_len_args (bool): if set to True the last argument defined in
- arg_types will be treated as a variable length argument and the
- argument's value will be stored as a list.
- _sql_names (list): determines the SQL name (1st item in the list) and
- aliases (subsequent items) for this function expression. These
- values are used to map this node to a name during parsing as well
- as to provide the function's name during SQL string generation. By
- default the SQL name is set to the expression's class name transformed
- to snake case.
+ Attributes:
+ is_var_len_args (bool): if set to True the last argument defined in arg_types will be
+ treated as a variable length argument and the argument's value will be stored as a list.
+ _sql_names (list): determines the SQL name (1st item in the list) and aliases (subsequent items)
+ for this function expression. These values are used to map this node to a name during parsing
+ as well as to provide the function's name during SQL string generation. By default the SQL
+ name is set to the expression's class name transformed to snake case.
"""
is_var_len_args = False
@@ -2558,7 +2729,7 @@ class Func(Condition):
raise NotImplementedError(
"SQL name is only supported by concrete function implementations"
)
- if not hasattr(cls, "_sql_names"):
+ if "_sql_names" not in cls.__dict__:
cls._sql_names = [camel_to_snake_case(cls.__name__)]
return cls._sql_names
@@ -2658,6 +2829,10 @@ class Cast(Func):
def to(self):
return self.args["to"]
+ @property
+ def output_name(self):
+ return self.name
+
class Collate(Binary):
pass
@@ -2956,6 +3131,14 @@ class Pow(Func):
_sql_names = ["POWER", "POW"]
+class PercentileCont(AggFunc):
+ pass
+
+
+class PercentileDisc(AggFunc):
+ pass
+
+
class Quantile(AggFunc):
arg_types = {"this": True, "quantile": True}
@@ -3213,12 +3396,13 @@ def _norm_arg(arg):
ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func))
+# Helpers
def maybe_parse(
- sql_or_expression,
+ sql_or_expression: str | Expression,
*,
- into=None,
- dialect=None,
- prefix=None,
+ into: t.Optional[IntoType] = None,
+ dialect: t.Optional[str] = None,
+ prefix: t.Optional[str] = None,
**opts,
) -> Expression:
"""Gracefully handle a possible string or expression.
@@ -3230,11 +3414,11 @@ def maybe_parse(
(IDENTIFIER this: x, quoted: False)
Args:
- sql_or_expression (str | Expression): the SQL code string or an expression
- into (Expression): the SQLGlot Expression to parse into
- dialect (str): the dialect used to parse the input expressions (in the case that an
+ sql_or_expression: the SQL code string or an expression
+ into: the SQLGlot Expression to parse into
+ dialect: the dialect used to parse the input expressions (in the case that an
input expression is a SQL string).
- prefix (str): a string to prefix the sql with before it gets parsed
+ prefix: a string to prefix the sql with before it gets parsed
(automatically includes a space)
**opts: other options to use to parse the input expressions (again, in the case
that an input expression is a SQL string).
@@ -3993,7 +4177,7 @@ def table_name(table) -> str:
"""Get the full name of a table as a string.
Args:
- table (exp.Table | str): Table expression node or string.
+ table (exp.Table | str): table expression node or string.
Examples:
>>> from sqlglot import exp, parse_one
@@ -4001,7 +4185,7 @@ def table_name(table) -> str:
'a.b.c'
Returns:
- str: the table name
+ The table name.
"""
table = maybe_parse(table, into=Table)
@@ -4024,8 +4208,8 @@ def replace_tables(expression, mapping):
"""Replace all tables in expression according to the mapping.
Args:
- expression (sqlglot.Expression): Expression node to be transformed and replaced
- mapping (Dict[str, str]): Mapping of table names
+ expression (sqlglot.Expression): expression node to be transformed and replaced.
+ mapping (Dict[str, str]): mapping of table names.
Examples:
>>> from sqlglot import exp, parse_one
@@ -4033,7 +4217,7 @@ def replace_tables(expression, mapping):
'SELECT * FROM c'
Returns:
- The mapped expression
+ The mapped expression.
"""
def _replace_tables(node):
@@ -4053,9 +4237,9 @@ def replace_placeholders(expression, *args, **kwargs):
"""Replace placeholders in an expression.
Args:
- expression (sqlglot.Expression): Expression node to be transformed and replaced
- args: Positional names that will substitute unnamed placeholders in the given order
- kwargs: Keyword arguments that will substitute named placeholders
+ expression (sqlglot.Expression): expression node to be transformed and replaced.
+ args: positional names that will substitute unnamed placeholders in the given order.
+ kwargs: keyword arguments that will substitute named placeholders.
Examples:
>>> from sqlglot import exp, parse_one
@@ -4065,7 +4249,7 @@ def replace_placeholders(expression, *args, **kwargs):
'SELECT * FROM foo WHERE a = b'
Returns:
- The mapped expression
+ The mapped expression.
"""
def _replace_placeholders(node, args, **kwargs):
@@ -4084,15 +4268,101 @@ def replace_placeholders(expression, *args, **kwargs):
return expression.transform(_replace_placeholders, iter(args), **kwargs)
+def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True) -> Expression:
+ """Transforms an expression by expanding all referenced sources into subqueries.
+
+ Examples:
+ >>> from sqlglot import parse_one
+ >>> expand(parse_one("select * from x AS z"), {"x": parse_one("select * from y")}).sql()
+ 'SELECT * FROM (SELECT * FROM y) AS z /* source: x */'
+
+ Args:
+ expression: The expression to expand.
+ sources: A dictionary of name to Subqueryables.
+ copy: Whether or not to copy the expression during transformation. Defaults to True.
+
+ Returns:
+ The transformed expression.
+ """
+
+ def _expand(node: Expression):
+ if isinstance(node, Table):
+ name = table_name(node)
+ source = sources.get(name)
+ if source:
+ subquery = source.subquery(node.alias or name)
+ subquery.comments = [f"source: {name}"]
+ return subquery
+ return node
+
+ return expression.transform(_expand, copy=copy)
+
+
+def func(name: str, *args, dialect: t.Optional[Dialect | str] = None, **kwargs) -> Func:
+ """
+ Returns a Func expression.
+
+ Examples:
+ >>> func("abs", 5).sql()
+ 'ABS(5)'
+
+ >>> func("cast", this=5, to=DataType.build("DOUBLE")).sql()
+ 'CAST(5 AS DOUBLE)'
+
+ Args:
+ name: the name of the function to build.
+ args: the args used to instantiate the function of interest.
+ dialect: the source dialect.
+ kwargs: the kwargs used to instantiate the function of interest.
+
+ Note:
+ The arguments `args` and `kwargs` are mutually exclusive.
+
+ Returns:
+ An instance of the function of interest, or an anonymous function, if `name` doesn't
+ correspond to an existing `sqlglot.expressions.Func` class.
+ """
+ if args and kwargs:
+ raise ValueError("Can't use both args and kwargs to instantiate a function.")
+
+ from sqlglot.dialects.dialect import Dialect
+
+ args = tuple(convert(arg) for arg in args)
+ kwargs = {key: convert(value) for key, value in kwargs.items()}
+
+ parser = Dialect.get_or_raise(dialect)().parser()
+ from_args_list = parser.FUNCTIONS.get(name.upper())
+
+ if from_args_list:
+ function = from_args_list(args) if args else from_args_list.__self__(**kwargs) # type: ignore
+ else:
+ kwargs = kwargs or {"expressions": args}
+ function = Anonymous(this=name, **kwargs)
+
+ for error_message in function.error_messages(args):
+ raise ValueError(error_message)
+
+ return function
+
+
def true():
+ """
+ Returns a true Boolean expression.
+ """
return Boolean(this=True)
def false():
+ """
+ Returns a false Boolean expression.
+ """
return Boolean(this=False)
def null():
+ """
+ Returns a Null expression.
+ """
return Null()