diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-01-30 17:08:37 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-01-30 17:08:37 +0000 |
commit | be1cb18ea28222fca384a5459a024b7e9af5cadb (patch) | |
tree | 4698c9069380a7c30ceb51129f93f6c8662315e4 /sqlglot/expressions.py | |
parent | Releasing debian version 10.5.6-1. (diff) | |
download | sqlglot-be1cb18ea28222fca384a5459a024b7e9af5cadb.tar.xz sqlglot-be1cb18ea28222fca384a5459a024b7e9af5cadb.zip |
Merging upstream version 10.5.10.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r-- | sqlglot/expressions.py | 410 |
1 files changed, 340 insertions, 70 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index be99fe2..f9751ca 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1,5 +1,12 @@ """ -.. include:: ../pdoc/docs/expressions.md +## Expressions + +Every AST node in SQLGlot is represented by a subclass of `Expression`. + +This module contains the implementation of all supported `Expression` types. Additionally, +it exposes a number of helper functions, which are mainly used to programmatically build +SQL expressions, such as `sqlglot.expressions.select`. +---- """ from __future__ import annotations @@ -27,35 +34,66 @@ from sqlglot.tokens import Token if t.TYPE_CHECKING: from sqlglot.dialects.dialect import Dialect + IntoType = t.Union[ + str, + t.Type[Expression], + t.Collection[t.Union[str, t.Type[Expression]]], + ] + class _Expression(type): def __new__(cls, clsname, bases, attrs): klass = super().__new__(cls, clsname, bases, attrs) + + # When an Expression class is created, its key is automatically set to be + # the lowercase version of the class' name. klass.key = clsname.lower() + + # This is so that docstrings are not inherited in pdoc + klass.__doc__ = klass.__doc__ or "" + return klass class Expression(metaclass=_Expression): """ - The base class for all expressions in a syntax tree. + The base class for all expressions in a syntax tree. Each Expression encapsulates any necessary + context, such as its child expressions, their names (arg keys), and whether a given child expression + is optional or not. Attributes: - arg_types (dict): determines arguments supported by this expression. - The key in a dictionary defines a unique key of an argument using - which the argument's value can be retrieved. The value is a boolean - flag which indicates whether the argument's value is required (True) - or optional (False). + key: a unique key for each class in the Expression hierarchy. This is useful for hashing + and representing expressions as strings. + arg_types: determines what arguments (child nodes) are supported by an expression. It + maps arg keys to booleans that indicate whether the corresponding args are optional. + + Example: + >>> class Foo(Expression): + ... arg_types = {"this": True, "expression": False} + + The above definition informs us that Foo is an Expression that requires an argument called + "this" and may also optionally receive an argument called "expression". + + Args: + args: a mapping used for retrieving the arguments of an expression, given their arg keys. + parent: a reference to the parent expression (or None, in case of root expressions). + arg_key: the arg key an expression is associated with, i.e. the name its parent expression + uses to refer to it. + comments: a list of comments that are associated with a given expression. This is used in + order to preserve comments when transpiling SQL code. + _type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the + optimizer, in order to enable some transformations that require type information. """ - key = "Expression" + key = "expression" arg_types = {"this": True} __slots__ = ("args", "parent", "arg_key", "comments", "_type") - def __init__(self, **args): - self.args = args - self.parent = None - self.arg_key = None - self.comments = None + def __init__(self, **args: t.Any): + self.args: t.Dict[str, t.Any] = args + self.parent: t.Optional[Expression] = None + self.arg_key: t.Optional[str] = None + self.comments: t.Optional[t.List[str]] = None self._type: t.Optional[DataType] = None for arg_key, value in self.args.items(): @@ -76,17 +114,30 @@ class Expression(metaclass=_Expression): @property def this(self): + """ + Retrieves the argument with key "this". + """ return self.args.get("this") @property def expression(self): + """ + Retrieves the argument with key "expression". + """ return self.args.get("expression") @property def expressions(self): + """ + Retrieves the argument with key "expressions". + """ return self.args.get("expressions") or [] def text(self, key): + """ + Returns a textual representation of the argument corresponding to "key". This can only be used + for args that are strings or leaf Expression instances, such as identifiers and literals. + """ field = self.args.get(key) if isinstance(field, str): return field @@ -96,14 +147,23 @@ class Expression(metaclass=_Expression): @property def is_string(self): + """ + Checks whether a Literal expression is a string. + """ return isinstance(self, Literal) and self.args["is_string"] @property def is_number(self): + """ + Checks whether a Literal expression is a number. + """ return isinstance(self, Literal) and not self.args["is_string"] @property def is_int(self): + """ + Checks whether a Literal expression is an integer. + """ if self.is_number: try: int(self.name) @@ -114,6 +174,9 @@ class Expression(metaclass=_Expression): @property def alias(self): + """ + Returns the alias of the expression, or an empty string if it's not aliased. + """ if isinstance(self.args.get("alias"), TableAlias): return self.args["alias"].name return self.text("alias") @@ -129,6 +192,24 @@ class Expression(metaclass=_Expression): return self.alias or self.name @property + def output_name(self): + """ + Name of the output column if this expression is a selection. + + If the Expression has no output name, an empty string is returned. + + Example: + >>> from sqlglot import parse_one + >>> parse_one("SELECT a").expressions[0].output_name + 'a' + >>> parse_one("SELECT b AS c").expressions[0].output_name + 'c' + >>> parse_one("SELECT 1 + 2").expressions[0].output_name + '' + """ + return "" + + @property def type(self) -> t.Optional[DataType]: return self._type @@ -145,6 +226,9 @@ class Expression(metaclass=_Expression): return copy def copy(self): + """ + Returns a deep copy of the expression. + """ new = deepcopy(self) for item, parent, _ in new.bfs(): if isinstance(item, Expression) and parent: @@ -169,7 +253,7 @@ class Expression(metaclass=_Expression): Sets `arg_key` to `value`. Args: - arg_key (str): name of the expression arg + arg_key (str): name of the expression arg. value: value to set the arg to. """ self.args[arg_key] = value @@ -203,8 +287,7 @@ class Expression(metaclass=_Expression): expression_types (type): the expression type(s) to match. Returns: - the node which matches the criteria or None if no node matching - the criteria was found. + The node which matches the criteria or None if no such node was found. """ return next(self.find_all(*expression_types, bfs=bfs), None) @@ -217,7 +300,7 @@ class Expression(metaclass=_Expression): expression_types (type): the expression type(s) to match. Returns: - the generator object. + The generator object. """ for expression, _, _ in self.walk(bfs=bfs): if isinstance(expression, expression_types): @@ -231,7 +314,7 @@ class Expression(metaclass=_Expression): expression_types (type): the expression type(s) to match. Returns: - the parent node + The parent node. """ ancestor = self.parent while ancestor and not isinstance(ancestor, expression_types): @@ -269,7 +352,7 @@ class Expression(metaclass=_Expression): the DFS (Depth-first) order. Returns: - the generator object. + The generator object. """ parent = parent or self.parent yield self, parent, key @@ -287,7 +370,7 @@ class Expression(metaclass=_Expression): the BFS (Breadth-first) order. Returns: - the generator object. + The generator object. """ queue = deque([(self, self.parent, None)]) @@ -341,32 +424,33 @@ class Expression(metaclass=_Expression): return self.sql() def __repr__(self): - return self.to_s() + return self._to_s() def sql(self, dialect: Dialect | str | None = None, **opts) -> str: """ Returns SQL string representation of this tree. - Args - dialect (str): the dialect of the output SQL string - (eg. "spark", "hive", "presto", "mysql"). - opts (dict): other :class:`~sqlglot.generator.Generator` options. + Args: + dialect: the dialect of the output SQL string (eg. "spark", "hive", "presto", "mysql"). + opts: other `sqlglot.generator.Generator` options. - Returns - the SQL string. + Returns: + The SQL string. """ from sqlglot.dialects import Dialect return Dialect.get_or_raise(dialect)().generate(self, **opts) - def to_s(self, hide_missing: bool = True, level: int = 0) -> str: + def _to_s(self, hide_missing: bool = True, level: int = 0) -> str: indent = "" if not level else "\n" indent += "".join([" "] * level) left = f"({self.key.upper()} " args: t.Dict[str, t.Any] = { k: ", ".join( - v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v) + v._to_s(hide_missing=hide_missing, level=level + 1) + if hasattr(v, "_to_s") + else str(v) for v in ensure_collection(vs) if v is not None ) @@ -394,7 +478,7 @@ class Expression(metaclass=_Expression): modified in place. Returns: - the transformed tree. + The transformed tree. """ node = self.copy() if copy else self new_node = fun(node, *args, **kwargs) @@ -423,8 +507,8 @@ class Expression(metaclass=_Expression): Args: expression (Expression|None): new node - Returns : - the new expression or expressions + Returns: + The new expression or expressions. """ if not self.parent: return expression @@ -458,6 +542,40 @@ class Expression(metaclass=_Expression): assert isinstance(self, type_) return self + def error_messages(self, args: t.Optional[t.Sequence] = None) -> t.List[str]: + """ + Checks if this expression is valid (e.g. all mandatory args are set). + + Args: + args: a sequence of values that were used to instantiate a Func expression. This is used + to check that the provided arguments don't exceed the function argument limit. + + Returns: + A list of error messages for all possible errors that were found. + """ + errors: t.List[str] = [] + + for k in self.args: + if k not in self.arg_types: + errors.append(f"Unexpected keyword: '{k}' for {self.__class__}") + for k, mandatory in self.arg_types.items(): + v = self.args.get(k) + if mandatory and (v is None or (isinstance(v, list) and not v)): + errors.append(f"Required keyword: '{k}' missing for {self.__class__}") + + if ( + args + and isinstance(self, Func) + and len(args) > len(self.arg_types) + and not self.is_var_len_args + ): + errors.append( + f"The number of provided arguments ({len(args)}) is greater than " + f"the maximum number of supported arguments ({len(self.arg_types)})" + ) + + return errors + def dump(self): """ Dump this Expression to a JSON-serializable dict. @@ -552,7 +670,7 @@ class DerivedTable(Expression): @property def named_selects(self): - return [select.alias_or_name for select in self.selects] + return [select.output_name for select in self.selects] class Unionable(Expression): @@ -654,6 +772,7 @@ class Create(Expression): "no_primary_index": False, "indexes": False, "no_schema_binding": False, + "begin": False, } @@ -696,7 +815,7 @@ class Show(Expression): class UserDefinedFunction(Expression): - arg_types = {"this": True, "expressions": False} + arg_types = {"this": True, "expressions": False, "wrapped": False} class UserDefinedFunctionKwarg(Expression): @@ -750,6 +869,10 @@ class Column(Condition): def table(self): return self.text("table") + @property + def output_name(self): + return self.name + class ColumnDef(Expression): arg_types = { @@ -865,6 +988,10 @@ class ForeignKey(Expression): } +class PrimaryKey(Expression): + arg_types = {"expressions": True, "options": False} + + class Unique(Expression): arg_types = {"expressions": True} @@ -904,6 +1031,10 @@ class Identifier(Expression): def __hash__(self): return hash((self.key, self.this.lower())) + @property + def output_name(self): + return self.name + class Index(Expression): arg_types = { @@ -996,6 +1127,10 @@ class Literal(Condition): def string(cls, string) -> Literal: return cls(this=str(string), is_string=True) + @property + def output_name(self): + return self.name + class Join(Expression): arg_types = { @@ -1186,7 +1321,7 @@ class SchemaCommentProperty(Property): class ReturnsProperty(Property): - arg_types = {"this": True, "is_table": False} + arg_types = {"this": True, "is_table": False, "table": False} class LanguageProperty(Property): @@ -1262,8 +1397,13 @@ class Qualify(Expression): pass +# https://www.ibm.com/docs/en/ias?topic=procedures-return-statement-in-sql +class Return(Expression): + pass + + class Reference(Expression): - arg_types = {"this": True, "expressions": True} + arg_types = {"this": True, "expressions": False, "options": False} class Tuple(Expression): @@ -1397,6 +1537,16 @@ class Table(Expression): "joins": False, "pivots": False, "hints": False, + "system_time": False, + } + + +# See the TSQL "Querying data in a system-versioned temporal table" page +class SystemTime(Expression): + arg_types = { + "this": False, + "expression": False, + "kind": True, } @@ -2027,7 +2177,7 @@ class Select(Subqueryable): @property def named_selects(self) -> t.List[str]: - return [e.alias_or_name for e in self.expressions if e.alias_or_name] + return [e.output_name for e in self.expressions if e.alias_or_name] @property def selects(self) -> t.List[Expression]: @@ -2051,6 +2201,10 @@ class Subquery(DerivedTable, Unionable): expression = expression.this return expression + @property + def output_name(self): + return self.alias + class TableSample(Expression): arg_types = { @@ -2066,6 +2220,16 @@ class TableSample(Expression): } +class Tag(Expression): + """Tags are used for generating arbitrary sql like SELECT <span>x</span>.""" + + arg_types = { + "this": False, + "prefix": False, + "postfix": False, + } + + class Pivot(Expression): arg_types = { "this": False, @@ -2106,6 +2270,10 @@ class Star(Expression): def name(self): return "*" + @property + def output_name(self): + return self.name + class Parameter(Expression): pass @@ -2143,6 +2311,8 @@ class DataType(Expression): TEXT = auto() MEDIUMTEXT = auto() LONGTEXT = auto() + MEDIUMBLOB = auto() + LONGBLOB = auto() BINARY = auto() VARBINARY = auto() INT = auto() @@ -2282,11 +2452,11 @@ class Rollback(Expression): class AlterTable(Expression): - arg_types = { - "this": True, - "actions": True, - "exists": False, - } + arg_types = {"this": True, "actions": True, "exists": False} + + +class AddConstraint(Expression): + arg_types = {"this": False, "expression": False, "enforced": False} # Binary expressions like (ADD a b) @@ -2456,6 +2626,10 @@ class Neg(Unary): class Alias(Expression): arg_types = {"this": True, "alias": False} + @property + def output_name(self): + return self.alias + class Aliases(Expression): arg_types = {"this": True, "expressions": True} @@ -2523,16 +2697,13 @@ class Func(Condition): """ The base class for all function expressions. - Attributes - is_var_len_args (bool): if set to True the last argument defined in - arg_types will be treated as a variable length argument and the - argument's value will be stored as a list. - _sql_names (list): determines the SQL name (1st item in the list) and - aliases (subsequent items) for this function expression. These - values are used to map this node to a name during parsing as well - as to provide the function's name during SQL string generation. By - default the SQL name is set to the expression's class name transformed - to snake case. + Attributes: + is_var_len_args (bool): if set to True the last argument defined in arg_types will be + treated as a variable length argument and the argument's value will be stored as a list. + _sql_names (list): determines the SQL name (1st item in the list) and aliases (subsequent items) + for this function expression. These values are used to map this node to a name during parsing + as well as to provide the function's name during SQL string generation. By default the SQL + name is set to the expression's class name transformed to snake case. """ is_var_len_args = False @@ -2558,7 +2729,7 @@ class Func(Condition): raise NotImplementedError( "SQL name is only supported by concrete function implementations" ) - if not hasattr(cls, "_sql_names"): + if "_sql_names" not in cls.__dict__: cls._sql_names = [camel_to_snake_case(cls.__name__)] return cls._sql_names @@ -2658,6 +2829,10 @@ class Cast(Func): def to(self): return self.args["to"] + @property + def output_name(self): + return self.name + class Collate(Binary): pass @@ -2956,6 +3131,14 @@ class Pow(Func): _sql_names = ["POWER", "POW"] +class PercentileCont(AggFunc): + pass + + +class PercentileDisc(AggFunc): + pass + + class Quantile(AggFunc): arg_types = {"this": True, "quantile": True} @@ -3213,12 +3396,13 @@ def _norm_arg(arg): ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func)) +# Helpers def maybe_parse( - sql_or_expression, + sql_or_expression: str | Expression, *, - into=None, - dialect=None, - prefix=None, + into: t.Optional[IntoType] = None, + dialect: t.Optional[str] = None, + prefix: t.Optional[str] = None, **opts, ) -> Expression: """Gracefully handle a possible string or expression. @@ -3230,11 +3414,11 @@ def maybe_parse( (IDENTIFIER this: x, quoted: False) Args: - sql_or_expression (str | Expression): the SQL code string or an expression - into (Expression): the SQLGlot Expression to parse into - dialect (str): the dialect used to parse the input expressions (in the case that an + sql_or_expression: the SQL code string or an expression + into: the SQLGlot Expression to parse into + dialect: the dialect used to parse the input expressions (in the case that an input expression is a SQL string). - prefix (str): a string to prefix the sql with before it gets parsed + prefix: a string to prefix the sql with before it gets parsed (automatically includes a space) **opts: other options to use to parse the input expressions (again, in the case that an input expression is a SQL string). @@ -3993,7 +4177,7 @@ def table_name(table) -> str: """Get the full name of a table as a string. Args: - table (exp.Table | str): Table expression node or string. + table (exp.Table | str): table expression node or string. Examples: >>> from sqlglot import exp, parse_one @@ -4001,7 +4185,7 @@ def table_name(table) -> str: 'a.b.c' Returns: - str: the table name + The table name. """ table = maybe_parse(table, into=Table) @@ -4024,8 +4208,8 @@ def replace_tables(expression, mapping): """Replace all tables in expression according to the mapping. Args: - expression (sqlglot.Expression): Expression node to be transformed and replaced - mapping (Dict[str, str]): Mapping of table names + expression (sqlglot.Expression): expression node to be transformed and replaced. + mapping (Dict[str, str]): mapping of table names. Examples: >>> from sqlglot import exp, parse_one @@ -4033,7 +4217,7 @@ def replace_tables(expression, mapping): 'SELECT * FROM c' Returns: - The mapped expression + The mapped expression. """ def _replace_tables(node): @@ -4053,9 +4237,9 @@ def replace_placeholders(expression, *args, **kwargs): """Replace placeholders in an expression. Args: - expression (sqlglot.Expression): Expression node to be transformed and replaced - args: Positional names that will substitute unnamed placeholders in the given order - kwargs: Keyword arguments that will substitute named placeholders + expression (sqlglot.Expression): expression node to be transformed and replaced. + args: positional names that will substitute unnamed placeholders in the given order. + kwargs: keyword arguments that will substitute named placeholders. Examples: >>> from sqlglot import exp, parse_one @@ -4065,7 +4249,7 @@ def replace_placeholders(expression, *args, **kwargs): 'SELECT * FROM foo WHERE a = b' Returns: - The mapped expression + The mapped expression. """ def _replace_placeholders(node, args, **kwargs): @@ -4084,15 +4268,101 @@ def replace_placeholders(expression, *args, **kwargs): return expression.transform(_replace_placeholders, iter(args), **kwargs) +def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True) -> Expression: + """Transforms an expression by expanding all referenced sources into subqueries. + + Examples: + >>> from sqlglot import parse_one + >>> expand(parse_one("select * from x AS z"), {"x": parse_one("select * from y")}).sql() + 'SELECT * FROM (SELECT * FROM y) AS z /* source: x */' + + Args: + expression: The expression to expand. + sources: A dictionary of name to Subqueryables. + copy: Whether or not to copy the expression during transformation. Defaults to True. + + Returns: + The transformed expression. + """ + + def _expand(node: Expression): + if isinstance(node, Table): + name = table_name(node) + source = sources.get(name) + if source: + subquery = source.subquery(node.alias or name) + subquery.comments = [f"source: {name}"] + return subquery + return node + + return expression.transform(_expand, copy=copy) + + +def func(name: str, *args, dialect: t.Optional[Dialect | str] = None, **kwargs) -> Func: + """ + Returns a Func expression. + + Examples: + >>> func("abs", 5).sql() + 'ABS(5)' + + >>> func("cast", this=5, to=DataType.build("DOUBLE")).sql() + 'CAST(5 AS DOUBLE)' + + Args: + name: the name of the function to build. + args: the args used to instantiate the function of interest. + dialect: the source dialect. + kwargs: the kwargs used to instantiate the function of interest. + + Note: + The arguments `args` and `kwargs` are mutually exclusive. + + Returns: + An instance of the function of interest, or an anonymous function, if `name` doesn't + correspond to an existing `sqlglot.expressions.Func` class. + """ + if args and kwargs: + raise ValueError("Can't use both args and kwargs to instantiate a function.") + + from sqlglot.dialects.dialect import Dialect + + args = tuple(convert(arg) for arg in args) + kwargs = {key: convert(value) for key, value in kwargs.items()} + + parser = Dialect.get_or_raise(dialect)().parser() + from_args_list = parser.FUNCTIONS.get(name.upper()) + + if from_args_list: + function = from_args_list(args) if args else from_args_list.__self__(**kwargs) # type: ignore + else: + kwargs = kwargs or {"expressions": args} + function = Anonymous(this=name, **kwargs) + + for error_message in function.error_messages(args): + raise ValueError(error_message) + + return function + + def true(): + """ + Returns a true Boolean expression. + """ return Boolean(this=True) def false(): + """ + Returns a false Boolean expression. + """ return Boolean(this=False) def null(): + """ + Returns a Null expression. + """ return Null() |