diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-11-11 08:54:35 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-11-11 08:54:35 +0000 |
commit | d1f00706bff58b863b0a1c5bf4adf39d36049d4c (patch) | |
tree | 3a8ecc5d1509d655d5df6b1455bc1e309da2c02c /sqlglot/expressions.py | |
parent | Releasing debian version 9.0.6-1. (diff) | |
download | sqlglot-d1f00706bff58b863b0a1c5bf4adf39d36049d4c.tar.xz sqlglot-d1f00706bff58b863b0a1c5bf4adf39d36049d4c.zip |
Merging upstream version 10.0.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r-- | sqlglot/expressions.py | 258 |
1 files changed, 169 insertions, 89 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 1691d85..57a2c88 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import datetime import numbers import re +import typing as t from collections import deque from copy import deepcopy from enum import auto @@ -9,12 +12,15 @@ from sqlglot.errors import ParseError from sqlglot.helper import ( AutoName, camel_to_snake_case, - ensure_list, - list_get, + ensure_collection, + seq_get, split_num_words, subclasses, ) +if t.TYPE_CHECKING: + from sqlglot.dialects.dialect import Dialect + class _Expression(type): def __new__(cls, clsname, bases, attrs): @@ -35,27 +41,30 @@ class Expression(metaclass=_Expression): or optional (False). """ - key = None + key = "Expression" arg_types = {"this": True} - __slots__ = ("args", "parent", "arg_key", "type") + __slots__ = ("args", "parent", "arg_key", "type", "comment") def __init__(self, **args): self.args = args self.parent = None self.arg_key = None self.type = None + self.comment = None for arg_key, value in self.args.items(): self._set_parent(arg_key, value) - def __eq__(self, other): + def __eq__(self, other) -> bool: return type(self) is type(other) and _norm_args(self) == _norm_args(other) - def __hash__(self): + def __hash__(self) -> int: return hash( ( self.key, - tuple((k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()), + tuple( + (k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items() + ), ) ) @@ -79,6 +88,19 @@ class Expression(metaclass=_Expression): return field.this return "" + def find_comment(self, key: str) -> str: + """ + Finds the comment that is attached to a specified child node. + + Args: + key: the key of the target child node (e.g. "this", "expression", etc). + + Returns: + The comment attached to the child node, or the empty string, if it doesn't exist. + """ + field = self.args.get(key) + return field.comment if isinstance(field, Expression) else "" + @property def is_string(self): return isinstance(self, Literal) and self.args["is_string"] @@ -114,7 +136,10 @@ class Expression(metaclass=_Expression): return self.alias or self.name def __deepcopy__(self, memo): - return self.__class__(**deepcopy(self.args)) + copy = self.__class__(**deepcopy(self.args)) + copy.comment = self.comment + copy.type = self.type + return copy def copy(self): new = deepcopy(self) @@ -249,9 +274,7 @@ class Expression(metaclass=_Expression): return for k, v in self.args.items(): - nodes = ensure_list(v) - - for node in nodes: + for node in ensure_collection(v): if isinstance(node, Expression): yield from node.dfs(self, k, prune) @@ -274,9 +297,7 @@ class Expression(metaclass=_Expression): if isinstance(item, Expression): for k, v in item.args.items(): - nodes = ensure_list(v) - - for node in nodes: + for node in ensure_collection(v): if isinstance(node, Expression): queue.append((node, item, k)) @@ -319,7 +340,7 @@ class Expression(metaclass=_Expression): def __repr__(self): return self.to_s() - def sql(self, dialect=None, **opts): + def sql(self, dialect: Dialect | str | None = None, **opts) -> str: """ Returns SQL string representation of this tree. @@ -335,7 +356,7 @@ class Expression(metaclass=_Expression): return Dialect.get_or_raise(dialect)().generate(self, **opts) - def to_s(self, hide_missing=True, level=0): + def to_s(self, hide_missing: bool = True, level: int = 0) -> str: indent = "" if not level else "\n" indent += "".join([" "] * level) left = f"({self.key.upper()} " @@ -343,11 +364,13 @@ class Expression(metaclass=_Expression): args = { k: ", ".join( v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v) - for v in ensure_list(vs) + for v in ensure_collection(vs) if v is not None ) for k, vs in self.args.items() } + args["comment"] = self.comment + args["type"] = self.type args = {k: v for k, v in args.items() if v or not hide_missing} right = ", ".join(f"{k}: {v}" for k, v in args.items()) @@ -578,17 +601,6 @@ class UDTF(DerivedTable, Unionable): pass -class Annotation(Expression): - arg_types = { - "this": True, - "expression": True, - } - - @property - def alias(self): - return self.expression.alias_or_name - - class Cache(Expression): arg_types = { "with": False, @@ -623,6 +635,38 @@ class Describe(Expression): pass +class Set(Expression): + arg_types = {"expressions": True} + + +class SetItem(Expression): + arg_types = { + "this": True, + "kind": False, + "collate": False, # MySQL SET NAMES statement + } + + +class Show(Expression): + arg_types = { + "this": True, + "target": False, + "offset": False, + "limit": False, + "like": False, + "where": False, + "db": False, + "full": False, + "mutex": False, + "query": False, + "channel": False, + "global": False, + "log": False, + "position": False, + "types": False, + } + + class UserDefinedFunction(Expression): arg_types = {"this": True, "expressions": False} @@ -864,18 +908,20 @@ class Literal(Condition): def __eq__(self, other): return ( - isinstance(other, Literal) and self.this == other.this and self.args["is_string"] == other.args["is_string"] + isinstance(other, Literal) + and self.this == other.this + and self.args["is_string"] == other.args["is_string"] ) def __hash__(self): return hash((self.key, self.this, self.args["is_string"])) @classmethod - def number(cls, number): + def number(cls, number) -> Literal: return cls(this=str(number), is_string=False) @classmethod - def string(cls, string): + def string(cls, string) -> Literal: return cls(this=str(string), is_string=True) @@ -1087,7 +1133,7 @@ class Properties(Expression): } @classmethod - def from_dict(cls, properties_dict): + def from_dict(cls, properties_dict) -> Properties: expressions = [] for key, value in properties_dict.items(): property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty) @@ -1323,7 +1369,7 @@ class Select(Subqueryable): **QUERY_MODIFIERS, } - def from_(self, *expressions, append=True, dialect=None, copy=True, **opts): + def from_(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: """ Set the FROM expression. @@ -1356,7 +1402,7 @@ class Select(Subqueryable): **opts, ) - def group_by(self, *expressions, append=True, dialect=None, copy=True, **opts): + def group_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: """ Set the GROUP BY expression. @@ -1392,7 +1438,7 @@ class Select(Subqueryable): **opts, ) - def order_by(self, *expressions, append=True, dialect=None, copy=True, **opts): + def order_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: """ Set the ORDER BY expression. @@ -1425,7 +1471,7 @@ class Select(Subqueryable): **opts, ) - def sort_by(self, *expressions, append=True, dialect=None, copy=True, **opts): + def sort_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: """ Set the SORT BY expression. @@ -1458,7 +1504,7 @@ class Select(Subqueryable): **opts, ) - def cluster_by(self, *expressions, append=True, dialect=None, copy=True, **opts): + def cluster_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: """ Set the CLUSTER BY expression. @@ -1491,7 +1537,7 @@ class Select(Subqueryable): **opts, ) - def limit(self, expression, dialect=None, copy=True, **opts): + def limit(self, expression, dialect=None, copy=True, **opts) -> Select: """ Set the LIMIT expression. @@ -1522,7 +1568,7 @@ class Select(Subqueryable): **opts, ) - def offset(self, expression, dialect=None, copy=True, **opts): + def offset(self, expression, dialect=None, copy=True, **opts) -> Select: """ Set the OFFSET expression. @@ -1553,7 +1599,7 @@ class Select(Subqueryable): **opts, ) - def select(self, *expressions, append=True, dialect=None, copy=True, **opts): + def select(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: """ Append to or set the SELECT expressions. @@ -1583,7 +1629,7 @@ class Select(Subqueryable): **opts, ) - def lateral(self, *expressions, append=True, dialect=None, copy=True, **opts): + def lateral(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: """ Append to or set the LATERAL expressions. @@ -1626,7 +1672,7 @@ class Select(Subqueryable): dialect=None, copy=True, **opts, - ): + ) -> Select: """ Append to or set the JOIN expressions. @@ -1672,7 +1718,7 @@ class Select(Subqueryable): join.this.replace(join.this.subquery()) if join_type: - natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) + natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) # type: ignore if natural: join.set("natural", True) if side: @@ -1681,12 +1727,12 @@ class Select(Subqueryable): join.set("kind", kind.text) if on: - on = and_(*ensure_list(on), dialect=dialect, **opts) + on = and_(*ensure_collection(on), dialect=dialect, **opts) join.set("on", on) if using: join = _apply_list_builder( - *ensure_list(using), + *ensure_collection(using), instance=join, arg="using", append=append, @@ -1705,7 +1751,7 @@ class Select(Subqueryable): **opts, ) - def where(self, *expressions, append=True, dialect=None, copy=True, **opts): + def where(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: """ Append to or set the WHERE expressions. @@ -1737,7 +1783,7 @@ class Select(Subqueryable): **opts, ) - def having(self, *expressions, append=True, dialect=None, copy=True, **opts): + def having(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: """ Append to or set the HAVING expressions. @@ -1769,7 +1815,7 @@ class Select(Subqueryable): **opts, ) - def distinct(self, distinct=True, copy=True): + def distinct(self, distinct=True, copy=True) -> Select: """ Set the OFFSET expression. @@ -1788,7 +1834,7 @@ class Select(Subqueryable): instance.set("distinct", Distinct() if distinct else None) return instance - def ctas(self, table, properties=None, dialect=None, copy=True, **opts): + def ctas(self, table, properties=None, dialect=None, copy=True, **opts) -> Create: """ Convert this expression to a CREATE TABLE AS statement. @@ -1826,11 +1872,11 @@ class Select(Subqueryable): ) @property - def named_selects(self): + def named_selects(self) -> t.List[str]: return [e.alias_or_name for e in self.expressions if e.alias_or_name] @property - def selects(self): + def selects(self) -> t.List[Expression]: return self.expressions @@ -1910,12 +1956,16 @@ class Parameter(Expression): pass +class SessionParameter(Expression): + arg_types = {"this": True, "kind": False} + + class Placeholder(Expression): arg_types = {"this": False} class Null(Condition): - arg_types = {} + arg_types: t.Dict[str, t.Any] = {} class Boolean(Condition): @@ -1936,6 +1986,7 @@ class DataType(Expression): NVARCHAR = auto() TEXT = auto() BINARY = auto() + VARBINARY = auto() INT = auto() TINYINT = auto() SMALLINT = auto() @@ -1975,7 +2026,7 @@ class DataType(Expression): UNKNOWN = auto() # Sentinel value, useful for type annotation @classmethod - def build(cls, dtype, **kwargs): + def build(cls, dtype, **kwargs) -> DataType: return DataType( this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()], **kwargs, @@ -2077,6 +2128,18 @@ class EQ(Binary, Predicate): pass +class NullSafeEQ(Binary, Predicate): + pass + + +class NullSafeNEQ(Binary, Predicate): + pass + + +class Distance(Binary): + pass + + class Escape(Binary): pass @@ -2101,15 +2164,11 @@ class Is(Binary, Predicate): pass -class Like(Binary, Predicate): - pass - - -class SimilarTo(Binary, Predicate): - pass +class Kwarg(Binary): + """Kwarg in special functions like func(kwarg => y).""" -class Distance(Binary): +class Like(Binary, Predicate): pass @@ -2133,6 +2192,10 @@ class NEQ(Binary, Predicate): pass +class SimilarTo(Binary, Predicate): + pass + + class Sub(Binary): pass @@ -2189,7 +2252,13 @@ class Distinct(Expression): class In(Predicate): - arg_types = {"this": True, "expressions": False, "query": False, "unnest": False, "field": False} + arg_types = { + "this": True, + "expressions": False, + "query": False, + "unnest": False, + "field": False, + } class TimeUnit(Expression): @@ -2255,7 +2324,9 @@ class Func(Condition): @classmethod def sql_names(cls): if cls is Func: - raise NotImplementedError("SQL name is only supported by concrete function implementations") + raise NotImplementedError( + "SQL name is only supported by concrete function implementations" + ) if not hasattr(cls, "_sql_names"): cls._sql_names = [camel_to_snake_case(cls.__name__)] return cls._sql_names @@ -2408,8 +2479,8 @@ class DateDiff(Func, TimeUnit): arg_types = {"this": True, "expression": True, "unit": False} -class DateTrunc(Func, TimeUnit): - arg_types = {"this": True, "unit": True, "zone": False} +class DateTrunc(Func): + arg_types = {"this": True, "expression": True, "zone": False} class DatetimeAdd(Func, TimeUnit): @@ -2791,6 +2862,10 @@ class Year(Func): pass +class Use(Expression): + pass + + def _norm_args(expression): args = {} @@ -2822,7 +2897,7 @@ def maybe_parse( dialect=None, prefix=None, **opts, -): +) -> t.Optional[Expression]: """Gracefully handle a possible string or expression. Example: @@ -3073,7 +3148,7 @@ def except_(left, right, distinct=True, dialect=None, **opts): return Except(this=left, expression=right, distinct=distinct) -def select(*expressions, dialect=None, **opts): +def select(*expressions, dialect=None, **opts) -> Select: """ Initializes a syntax tree from one or multiple SELECT expressions. @@ -3095,7 +3170,7 @@ def select(*expressions, dialect=None, **opts): return Select().select(*expressions, dialect=dialect, **opts) -def from_(*expressions, dialect=None, **opts): +def from_(*expressions, dialect=None, **opts) -> Select: """ Initializes a syntax tree from a FROM expression. @@ -3117,7 +3192,7 @@ def from_(*expressions, dialect=None, **opts): return Select().from_(*expressions, dialect=dialect, **opts) -def update(table, properties, where=None, from_=None, dialect=None, **opts): +def update(table, properties, where=None, from_=None, dialect=None, **opts) -> Update: """ Creates an update statement. @@ -3139,7 +3214,10 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts): update = Update(this=maybe_parse(table, into=Table, dialect=dialect)) update.set( "expressions", - [EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v)) for k, v in properties.items()], + [ + EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v)) + for k, v in properties.items() + ], ) if from_: update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts)) @@ -3150,7 +3228,7 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts): return update -def delete(table, where=None, dialect=None, **opts): +def delete(table, where=None, dialect=None, **opts) -> Delete: """ Builds a delete statement. @@ -3174,7 +3252,7 @@ def delete(table, where=None, dialect=None, **opts): ) -def condition(expression, dialect=None, **opts): +def condition(expression, dialect=None, **opts) -> Condition: """ Initialize a logical condition expression. @@ -3199,7 +3277,7 @@ def condition(expression, dialect=None, **opts): Returns: Condition: the expression """ - return maybe_parse( + return maybe_parse( # type: ignore expression, into=Condition, dialect=dialect, @@ -3207,7 +3285,7 @@ def condition(expression, dialect=None, **opts): ) -def and_(*expressions, dialect=None, **opts): +def and_(*expressions, dialect=None, **opts) -> And: """ Combine multiple conditions with an AND logical operator. @@ -3227,7 +3305,7 @@ def and_(*expressions, dialect=None, **opts): return _combine(expressions, And, dialect, **opts) -def or_(*expressions, dialect=None, **opts): +def or_(*expressions, dialect=None, **opts) -> Or: """ Combine multiple conditions with an OR logical operator. @@ -3247,7 +3325,7 @@ def or_(*expressions, dialect=None, **opts): return _combine(expressions, Or, dialect, **opts) -def not_(expression, dialect=None, **opts): +def not_(expression, dialect=None, **opts) -> Not: """ Wrap a condition with a NOT operator. @@ -3272,14 +3350,14 @@ def not_(expression, dialect=None, **opts): return Not(this=_wrap_operator(this)) -def paren(expression): +def paren(expression) -> Paren: return Paren(this=expression) SAFE_IDENTIFIER_RE = re.compile(r"^[a-zA-Z][\w]*$") -def to_identifier(alias, quoted=None): +def to_identifier(alias, quoted=None) -> t.Optional[Identifier]: if alias is None: return None if isinstance(alias, Identifier): @@ -3293,16 +3371,16 @@ def to_identifier(alias, quoted=None): return identifier -def to_table(sql_path: str, **kwargs) -> Table: +def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]: """ Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional. - If a table is passed in then that table is returned. Args: - sql_path(str|Table): `[catalog].[schema].[table]` string + sql_path: a `[catalog].[schema].[table]` string. + Returns: - Table: A table expression + A table expression. """ if sql_path is None or isinstance(sql_path, Table): return sql_path @@ -3393,7 +3471,7 @@ def subquery(expression, alias=None, dialect=None, **opts): return Select().from_(expression, dialect=dialect, **opts) -def column(col, table=None, quoted=None): +def column(col, table=None, quoted=None) -> Column: """ Build a Column. Args: @@ -3408,7 +3486,7 @@ def column(col, table=None, quoted=None): ) -def table_(table, db=None, catalog=None, quoted=None, alias=None): +def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table: """Build a Table. Args: @@ -3427,7 +3505,7 @@ def table_(table, db=None, catalog=None, quoted=None, alias=None): ) -def values(values, alias=None): +def values(values, alias=None) -> Values: """Build VALUES statement. Example: @@ -3449,7 +3527,7 @@ def values(values, alias=None): ) -def convert(value): +def convert(value) -> Expression: """Convert a python value into an expression object. Raises an error if a conversion is not possible. @@ -3500,15 +3578,14 @@ def replace_children(expression, fun): for cn in child_nodes: if isinstance(cn, Expression): - cns = ensure_list(fun(cn)) - for child_node in cns: + for child_node in ensure_collection(fun(cn)): new_child_nodes.append(child_node) child_node.parent = expression child_node.arg_key = k else: new_child_nodes.append(cn) - expression.args[k] = new_child_nodes if is_list_arg else list_get(new_child_nodes, 0) + expression.args[k] = new_child_nodes if is_list_arg else seq_get(new_child_nodes, 0) def column_table_names(expression): @@ -3529,7 +3606,7 @@ def column_table_names(expression): return list(dict.fromkeys(column.table for column in expression.find_all(Column))) -def table_name(table): +def table_name(table) -> str: """Get the full name of a table as a string. Args: @@ -3546,6 +3623,9 @@ def table_name(table): table = maybe_parse(table, into=Table) + if not table: + raise ValueError(f"Cannot parse {table}") + return ".".join( part for part in ( |