diff options
Diffstat (limited to 'sqlglot/expressions.py')
-rw-r--r-- | sqlglot/expressions.py | 1354 |
1 files changed, 863 insertions, 491 deletions
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 9e7379d..a4c4e95 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -21,6 +21,7 @@ from collections import deque from copy import deepcopy from enum import auto +from sqlglot._typing import E from sqlglot.errors import ParseError from sqlglot.helper import ( AutoName, @@ -28,7 +29,6 @@ from sqlglot.helper import ( ensure_collection, ensure_list, seq_get, - split_num_words, subclasses, ) from sqlglot.tokens import Token @@ -36,8 +36,6 @@ from sqlglot.tokens import Token if t.TYPE_CHECKING: from sqlglot.dialects.dialect import DialectType -E = t.TypeVar("E", bound="Expression") - class _Expression(type): def __new__(cls, clsname, bases, attrs): @@ -200,11 +198,11 @@ class Expression(metaclass=_Expression): return self.text("this") @property - def alias_or_name(self): + def alias_or_name(self) -> str: return self.alias or self.name @property - def output_name(self): + def output_name(self) -> str: """ Name of the output column if this expression is a selection. @@ -264,7 +262,7 @@ class Expression(metaclass=_Expression): if comments: self.comments.extend(comments) - def append(self, arg_key, value): + def append(self, arg_key: str, value: t.Any) -> None: """ Appends value to arg_key if it's a list or sets it as a new list. @@ -277,7 +275,7 @@ class Expression(metaclass=_Expression): self.args[arg_key].append(value) self._set_parent(arg_key, value) - def set(self, arg_key, value): + def set(self, arg_key: str, value: t.Any) -> None: """ Sets `arg_key` to `value`. @@ -288,7 +286,7 @@ class Expression(metaclass=_Expression): self.args[arg_key] = value self._set_parent(arg_key, value) - def _set_parent(self, arg_key, value): + def _set_parent(self, arg_key: str, value: t.Any) -> None: if hasattr(value, "parent"): value.parent = self value.arg_key = arg_key @@ -299,7 +297,7 @@ class Expression(metaclass=_Expression): v.arg_key = arg_key @property - def depth(self): + def depth(self) -> int: """ Returns the depth of this tree. """ @@ -318,26 +316,28 @@ class Expression(metaclass=_Expression): if hasattr(vs, "parent"): yield k, vs - def find(self, *expression_types: t.Type[E], bfs=True) -> E | None: + def find(self, *expression_types: t.Type[E], bfs: bool = True) -> t.Optional[E]: """ Returns the first node in this tree which matches at least one of the specified types. Args: expression_types: the expression type(s) to match. + bfs: whether to search the AST using the BFS algorithm (DFS is used if false). Returns: The node which matches the criteria or None if no such node was found. """ return next(self.find_all(*expression_types, bfs=bfs), None) - def find_all(self, *expression_types: t.Type[E], bfs=True) -> t.Iterator[E]: + def find_all(self, *expression_types: t.Type[E], bfs: bool = True) -> t.Iterator[E]: """ Returns a generator object which visits all nodes in this tree and only yields those that match at least one of the specified expression types. Args: expression_types: the expression type(s) to match. + bfs: whether to search the AST using the BFS algorithm (DFS is used if false). Returns: The generator object. @@ -346,7 +346,7 @@ class Expression(metaclass=_Expression): if isinstance(expression, expression_types): yield expression - def find_ancestor(self, *expression_types: t.Type[E]) -> E | None: + def find_ancestor(self, *expression_types: t.Type[E]) -> t.Optional[E]: """ Returns a nearest parent matching expression_types. @@ -362,14 +362,14 @@ class Expression(metaclass=_Expression): return t.cast(E, ancestor) @property - def parent_select(self): + def parent_select(self) -> t.Optional[Select]: """ Returns the parent select statement. """ return self.find_ancestor(Select) @property - def same_parent(self): + def same_parent(self) -> bool: """Returns if the parent is the same class as itself.""" return type(self.parent) is self.__class__ @@ -469,10 +469,10 @@ class Expression(metaclass=_Expression): if not type(node) is self.__class__: yield node.unnest() if unnest else node - def __str__(self): + def __str__(self) -> str: return self.sql() - def __repr__(self): + def __repr__(self) -> str: return self._to_s() def sql(self, dialect: DialectType = None, **opts) -> str: @@ -541,6 +541,14 @@ class Expression(metaclass=_Expression): replace_children(new_node, lambda child: child.transform(fun, *args, copy=False, **kwargs)) return new_node + @t.overload + def replace(self, expression: E) -> E: + ... + + @t.overload + def replace(self, expression: None) -> None: + ... + def replace(self, expression): """ Swap out this expression with a new expression. @@ -554,7 +562,7 @@ class Expression(metaclass=_Expression): 'SELECT y FROM tbl' Args: - expression (Expression|None): new node + expression: new node Returns: The new expression or expressions. @@ -568,7 +576,7 @@ class Expression(metaclass=_Expression): replace_children(parent, lambda child: expression if child is self else child) return expression - def pop(self): + def pop(self: E) -> E: """ Remove this expression from its AST. @@ -578,7 +586,7 @@ class Expression(metaclass=_Expression): self.replace(None) return self - def assert_is(self, type_): + def assert_is(self, type_: t.Type[E]) -> E: """ Assert that this `Expression` is an instance of `type_`. @@ -656,7 +664,13 @@ ExpOrStr = t.Union[str, Expression] class Condition(Expression): - def and_(self, *expressions, dialect=None, copy=True, **opts): + def and_( + self, + *expressions: t.Optional[ExpOrStr], + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Condition: """ AND this condition with one or multiple expressions. @@ -665,18 +679,24 @@ class Condition(Expression): 'x = 1 AND y = 1' Args: - *expressions (str | Expression): the SQL code strings to parse. + *expressions: the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. - dialect (str): the dialect used to parse the input expression. - copy (bool): whether or not to copy the involved expressions (only applies to Expressions). - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expression. + copy: whether or not to copy the involved expressions (only applies to Expressions). + opts: other options to use to parse the input expressions. Returns: - And: the new condition. + The new And condition. """ return and_(self, *expressions, dialect=dialect, copy=copy, **opts) - def or_(self, *expressions, dialect=None, copy=True, **opts): + def or_( + self, + *expressions: t.Optional[ExpOrStr], + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Condition: """ OR this condition with one or multiple expressions. @@ -685,18 +705,18 @@ class Condition(Expression): 'x = 1 OR y = 1' Args: - *expressions (str | Expression): the SQL code strings to parse. + *expressions: the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. - dialect (str): the dialect used to parse the input expression. - copy (bool): whether or not to copy the involved expressions (only applies to Expressions). - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expression. + copy: whether or not to copy the involved expressions (only applies to Expressions). + opts: other options to use to parse the input expressions. Returns: - Or: the new condition. + The new Or condition. """ return or_(self, *expressions, dialect=dialect, copy=copy, **opts) - def not_(self, copy=True): + def not_(self, copy: bool = True): """ Wrap this condition with NOT. @@ -705,14 +725,24 @@ class Condition(Expression): 'NOT x = 1' Args: - copy (bool): whether or not to copy this object. + copy: whether or not to copy this object. Returns: - Not: the new condition. + The new Not instance. """ return not_(self, copy=copy) - def _binop(self, klass: t.Type[E], other: ExpOrStr, reverse=False) -> E: + def as_( + self, + alias: str | Identifier, + quoted: t.Optional[bool] = None, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Alias: + return alias_(self, alias, quoted=quoted, dialect=dialect, copy=copy, **opts) + + def _binop(self, klass: t.Type[E], other: t.Any, reverse: bool = False) -> E: this = self.copy() other = convert(other, copy=True) if not isinstance(this, klass) and not isinstance(other, klass): @@ -728,7 +758,7 @@ class Condition(Expression): ) def isin( - self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy=True, **opts + self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy: bool = True, **opts ) -> In: return In( this=_maybe_copy(self, copy), @@ -736,92 +766,95 @@ class Condition(Expression): query=maybe_parse(query, copy=copy, **opts) if query else None, ) - def between(self, low: t.Any, high: t.Any, copy=True, **opts) -> Between: + def between(self, low: t.Any, high: t.Any, copy: bool = True, **opts) -> Between: return Between( this=_maybe_copy(self, copy), low=convert(low, copy=copy, **opts), high=convert(high, copy=copy, **opts), ) + def is_(self, other: ExpOrStr) -> Is: + return self._binop(Is, other) + def like(self, other: ExpOrStr) -> Like: return self._binop(Like, other) def ilike(self, other: ExpOrStr) -> ILike: return self._binop(ILike, other) - def eq(self, other: ExpOrStr) -> EQ: + def eq(self, other: t.Any) -> EQ: return self._binop(EQ, other) - def neq(self, other: ExpOrStr) -> NEQ: + def neq(self, other: t.Any) -> NEQ: return self._binop(NEQ, other) def rlike(self, other: ExpOrStr) -> RegexpLike: return self._binop(RegexpLike, other) - def __lt__(self, other: ExpOrStr) -> LT: + def __lt__(self, other: t.Any) -> LT: return self._binop(LT, other) - def __le__(self, other: ExpOrStr) -> LTE: + def __le__(self, other: t.Any) -> LTE: return self._binop(LTE, other) - def __gt__(self, other: ExpOrStr) -> GT: + def __gt__(self, other: t.Any) -> GT: return self._binop(GT, other) - def __ge__(self, other: ExpOrStr) -> GTE: + def __ge__(self, other: t.Any) -> GTE: return self._binop(GTE, other) - def __add__(self, other: ExpOrStr) -> Add: + def __add__(self, other: t.Any) -> Add: return self._binop(Add, other) - def __radd__(self, other: ExpOrStr) -> Add: + def __radd__(self, other: t.Any) -> Add: return self._binop(Add, other, reverse=True) - def __sub__(self, other: ExpOrStr) -> Sub: + def __sub__(self, other: t.Any) -> Sub: return self._binop(Sub, other) - def __rsub__(self, other: ExpOrStr) -> Sub: + def __rsub__(self, other: t.Any) -> Sub: return self._binop(Sub, other, reverse=True) - def __mul__(self, other: ExpOrStr) -> Mul: + def __mul__(self, other: t.Any) -> Mul: return self._binop(Mul, other) - def __rmul__(self, other: ExpOrStr) -> Mul: + def __rmul__(self, other: t.Any) -> Mul: return self._binop(Mul, other, reverse=True) - def __truediv__(self, other: ExpOrStr) -> Div: + def __truediv__(self, other: t.Any) -> Div: return self._binop(Div, other) - def __rtruediv__(self, other: ExpOrStr) -> Div: + def __rtruediv__(self, other: t.Any) -> Div: return self._binop(Div, other, reverse=True) - def __floordiv__(self, other: ExpOrStr) -> IntDiv: + def __floordiv__(self, other: t.Any) -> IntDiv: return self._binop(IntDiv, other) - def __rfloordiv__(self, other: ExpOrStr) -> IntDiv: + def __rfloordiv__(self, other: t.Any) -> IntDiv: return self._binop(IntDiv, other, reverse=True) - def __mod__(self, other: ExpOrStr) -> Mod: + def __mod__(self, other: t.Any) -> Mod: return self._binop(Mod, other) - def __rmod__(self, other: ExpOrStr) -> Mod: + def __rmod__(self, other: t.Any) -> Mod: return self._binop(Mod, other, reverse=True) - def __pow__(self, other: ExpOrStr) -> Pow: + def __pow__(self, other: t.Any) -> Pow: return self._binop(Pow, other) - def __rpow__(self, other: ExpOrStr) -> Pow: + def __rpow__(self, other: t.Any) -> Pow: return self._binop(Pow, other, reverse=True) - def __and__(self, other: ExpOrStr) -> And: + def __and__(self, other: t.Any) -> And: return self._binop(And, other) - def __rand__(self, other: ExpOrStr) -> And: + def __rand__(self, other: t.Any) -> And: return self._binop(And, other, reverse=True) - def __or__(self, other: ExpOrStr) -> Or: + def __or__(self, other: t.Any) -> Or: return self._binop(Or, other) - def __ror__(self, other: ExpOrStr) -> Or: + def __ror__(self, other: t.Any) -> Or: return self._binop(Or, other, reverse=True) def __neg__(self) -> Neg: @@ -837,12 +870,11 @@ class Predicate(Condition): class DerivedTable(Expression): @property - def alias_column_names(self): + def alias_column_names(self) -> t.List[str]: table_alias = self.args.get("alias") if not table_alias: return [] - column_list = table_alias.assert_is(TableAlias).args.get("columns") or [] - return [c.name for c in column_list] + return [c.name for c in table_alias.args.get("columns") or []] @property def selects(self): @@ -854,7 +886,9 @@ class DerivedTable(Expression): class Unionable(Expression): - def union(self, expression, distinct=True, dialect=None, **opts): + def union( + self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts + ) -> Unionable: """ Builds a UNION expression. @@ -864,17 +898,20 @@ class Unionable(Expression): 'SELECT * FROM foo UNION SELECT * FROM bla' Args: - expression (str | Expression): the SQL code string. + expression: the SQL code string. If an `Expression` instance is passed, it will be used as-is. - distinct (bool): set the DISTINCT flag if and only if this is true. - dialect (str): the dialect used to parse the input expression. - opts (kwargs): other options to use to parse the input expressions. + distinct: set the DISTINCT flag if and only if this is true. + dialect: the dialect used to parse the input expression. + opts: other options to use to parse the input expressions. + Returns: - Union: the Union expression. + The new Union expression. """ return union(left=self, right=expression, distinct=distinct, dialect=dialect, **opts) - def intersect(self, expression, distinct=True, dialect=None, **opts): + def intersect( + self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts + ) -> Unionable: """ Builds an INTERSECT expression. @@ -884,17 +921,20 @@ class Unionable(Expression): 'SELECT * FROM foo INTERSECT SELECT * FROM bla' Args: - expression (str | Expression): the SQL code string. + expression: the SQL code string. If an `Expression` instance is passed, it will be used as-is. - distinct (bool): set the DISTINCT flag if and only if this is true. - dialect (str): the dialect used to parse the input expression. - opts (kwargs): other options to use to parse the input expressions. + distinct: set the DISTINCT flag if and only if this is true. + dialect: the dialect used to parse the input expression. + opts: other options to use to parse the input expressions. + Returns: - Intersect: the Intersect expression + The new Intersect expression. """ return intersect(left=self, right=expression, distinct=distinct, dialect=dialect, **opts) - def except_(self, expression, distinct=True, dialect=None, **opts): + def except_( + self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts + ) -> Unionable: """ Builds an EXCEPT expression. @@ -904,13 +944,14 @@ class Unionable(Expression): 'SELECT * FROM foo EXCEPT SELECT * FROM bla' Args: - expression (str | Expression): the SQL code string. + expression: the SQL code string. If an `Expression` instance is passed, it will be used as-is. - distinct (bool): set the DISTINCT flag if and only if this is true. - dialect (str): the dialect used to parse the input expression. - opts (kwargs): other options to use to parse the input expressions. + distinct: set the DISTINCT flag if and only if this is true. + dialect: the dialect used to parse the input expression. + opts: other options to use to parse the input expressions. + Returns: - Except: the Except expression + The new Except expression. """ return except_(left=self, right=expression, distinct=distinct, dialect=dialect, **opts) @@ -949,6 +990,17 @@ class Create(Expression): "indexes": False, "no_schema_binding": False, "begin": False, + "clone": False, + } + + +# https://docs.snowflake.com/en/sql-reference/sql/create-clone +class Clone(Expression): + arg_types = { + "this": True, + "when": False, + "kind": False, + "expression": False, } @@ -1038,6 +1090,10 @@ class ByteString(Condition): pass +class RawString(Condition): + pass + + class Column(Condition): arg_types = {"this": True, "table": False, "db": False, "catalog": False, "join_mark": False} @@ -1060,7 +1116,11 @@ class Column(Condition): @property def parts(self) -> t.List[Identifier]: """Return the parts of a column in order catalog, db, table, name.""" - return [part for part in reversed(list(self.args.values())) if part] + return [ + t.cast(Identifier, self.args[part]) + for part in ("catalog", "db", "table", "this") + if self.args.get(part) + ] def to_dot(self) -> Dot: """Converts the column into a dot expression.""" @@ -1116,6 +1176,27 @@ class Comment(Expression): arg_types = {"this": True, "kind": True, "expression": True, "exists": False} +# https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree#mergetree-table-ttl +class MergeTreeTTLAction(Expression): + arg_types = { + "this": True, + "delete": False, + "recompress": False, + "to_disk": False, + "to_volume": False, + } + + +# https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree#mergetree-table-ttl +class MergeTreeTTL(Expression): + arg_types = { + "expressions": True, + "where": False, + "group": False, + "aggregates": False, + } + + class ColumnConstraint(Expression): arg_types = {"this": False, "kind": True} @@ -1172,6 +1253,8 @@ class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind): # this: True -> ALWAYS, this: False -> BY DEFAULT arg_types = { "this": False, + "expression": False, + "on_null": False, "start": False, "increment": False, "minvalue": False, @@ -1202,7 +1285,7 @@ class TitleColumnConstraint(ColumnConstraintKind): class UniqueColumnConstraint(ColumnConstraintKind): - arg_types: t.Dict[str, t.Any] = {} + arg_types = {"this": False} class UppercaseColumnConstraint(ColumnConstraintKind): @@ -1255,7 +1338,7 @@ class Delete(Expression): def where( self, - *expressions: ExpOrStr, + *expressions: t.Optional[ExpOrStr], append: bool = True, dialect: DialectType = None, copy: bool = True, @@ -1367,10 +1450,6 @@ class PrimaryKey(Expression): arg_types = {"expressions": True, "options": False} -class Unique(Expression): - arg_types = {"expressions": True} - - # https://www.postgresql.org/docs/9.1/sql-selectinto.html # https://docs.aws.amazon.com/redshift/latest/dg/r_SELECT_INTO.html#r_SELECT_INTO-examples class Into(Expression): @@ -1378,7 +1457,13 @@ class Into(Expression): class From(Expression): - arg_types = {"expressions": True} + @property + def name(self) -> str: + return self.this.name + + @property + def alias_or_name(self) -> str: + return self.this.alias_or_name class Having(Expression): @@ -1397,7 +1482,7 @@ class Identifier(Expression): arg_types = {"this": True, "quoted": False} @property - def quoted(self): + def quoted(self) -> bool: return bool(self.args.get("quoted")) @property @@ -1407,7 +1492,7 @@ class Identifier(Expression): return self.this.lower() @property - def output_name(self): + def output_name(self) -> str: return self.name @@ -1420,6 +1505,7 @@ class Index(Expression): "unique": False, "primary": False, "amp": False, # teradata + "partition_by": False, # teradata } @@ -1436,6 +1522,42 @@ class Insert(Expression): "alternative": False, } + def with_( + self, + alias: ExpOrStr, + as_: ExpOrStr, + recursive: t.Optional[bool] = None, + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Insert: + """ + Append to or set the common table expressions. + + Example: + >>> insert("SELECT x FROM cte", "t").with_("cte", as_="SELECT * FROM tbl").sql() + 'WITH cte AS (SELECT * FROM tbl) INSERT INTO t SELECT x FROM cte' + + Args: + alias: the SQL code string to parse as the table name. + If an `Expression` instance is passed, this is used as-is. + as_: the SQL code string to parse as the table expression. + If an `Expression` instance is passed, it will be used as-is. + recursive: set the RECURSIVE part of the expression. Defaults to `False`. + append: if `True`, add to any existing expressions. + Otherwise, this resets the expressions. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + The modified expression. + """ + return _apply_cte_builder( + self, alias, as_, recursive=recursive, append=append, dialect=dialect, copy=copy, **opts + ) + class OnConflict(Expression): arg_types = { @@ -1492,6 +1614,7 @@ class Group(Expression): "grouping_sets": False, "cube": False, "rollup": False, + "totals": False, } @@ -1519,7 +1642,7 @@ class Literal(Condition): return cls(this=str(string), is_string=True) @property - def output_name(self): + def output_name(self) -> str: return self.name @@ -1531,26 +1654,34 @@ class Join(Expression): "kind": False, "using": False, "natural": False, + "global": False, "hint": False, } @property - def kind(self): + def kind(self) -> str: return self.text("kind").upper() @property - def side(self): + def side(self) -> str: return self.text("side").upper() @property - def hint(self): + def hint(self) -> str: return self.text("hint").upper() @property - def alias_or_name(self): + def alias_or_name(self) -> str: return self.this.alias_or_name - def on(self, *expressions, append=True, dialect=None, copy=True, **opts): + def on( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Join: """ Append to or set the ON expressions. @@ -1560,17 +1691,17 @@ class Join(Expression): 'JOIN x ON y = 1' Args: - *expressions (str | Expression): the SQL code strings to parse. + *expressions: the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. Multiple expressions are combined with an AND operator. - append (bool): if `True`, AND the new expressions to any existing expression. + append: if `True`, AND the new expressions to any existing expression. Otherwise, this resets the expression. - dialect (str): the dialect used to parse the input expressions. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: - Join: the modified join expression. + The modified Join expression. """ join = _apply_conjunction_builder( *expressions, @@ -1587,7 +1718,14 @@ class Join(Expression): return join - def using(self, *expressions, append=True, dialect=None, copy=True, **opts): + def using( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Join: """ Append to or set the USING expressions. @@ -1597,16 +1735,16 @@ class Join(Expression): 'JOIN x USING (foo, bla)' Args: - *expressions (str | Expression): the SQL code strings to parse. + *expressions: the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. - append (bool): if `True`, concatenate the new expressions to the existing "using" list. + append: if `True`, concatenate the new expressions to the existing "using" list. Otherwise, this resets the expression. - dialect (str): the dialect used to parse the input expressions. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: - Join: the modified join expression. + The modified Join expression. """ join = _apply_list_builder( *expressions, @@ -1677,10 +1815,6 @@ class Property(Expression): arg_types = {"this": True, "value": True} -class AfterJournalProperty(Property): - arg_types = {"no": True, "dual": False, "local": False} - - class AlgorithmProperty(Property): arg_types = {"this": True} @@ -1706,7 +1840,13 @@ class CollateProperty(Property): class DataBlocksizeProperty(Property): - arg_types = {"size": False, "units": False, "min": False, "default": False} + arg_types = { + "size": False, + "units": False, + "minimum": False, + "maximum": False, + "default": False, + } class DefinerProperty(Property): @@ -1760,7 +1900,13 @@ class IsolatedLoadingProperty(Property): class JournalProperty(Property): - arg_types = {"no": True, "dual": False, "before": False} + arg_types = { + "no": False, + "dual": False, + "before": False, + "local": False, + "after": False, + } class LanguageProperty(Property): @@ -1798,11 +1944,11 @@ class MergeBlockRatioProperty(Property): class NoPrimaryIndexProperty(Property): - arg_types = {"this": False} + arg_types = {} class OnCommitProperty(Property): - arg_type = {"this": False} + arg_type = {"delete": False} class PartitionedByProperty(Property): @@ -1846,6 +1992,10 @@ class SetProperty(Property): arg_types = {"multi": True} +class SettingsProperty(Property): + arg_types = {"expressions": True} + + class SortKeyProperty(Property): arg_types = {"this": True, "compound": False} @@ -1858,12 +2008,8 @@ class StabilityProperty(Property): arg_types = {"this": True} -class TableFormatProperty(Property): - arg_types = {"this": True} - - class TemporaryProperty(Property): - arg_types = {"global_": True} + arg_types = {} class TransientProperty(Property): @@ -1903,7 +2049,6 @@ class Properties(Expression): "RETURNS": ReturnsProperty, "ROW_FORMAT": RowFormatProperty, "SORTKEY": SortKeyProperty, - "TABLE_FORMAT": TableFormatProperty, } PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()} @@ -1932,7 +2077,7 @@ class Properties(Expression): UNSUPPORTED = auto() @classmethod - def from_dict(cls, properties_dict) -> Properties: + def from_dict(cls, properties_dict: t.Dict) -> Properties: expressions = [] for key, value in properties_dict.items(): property_cls = cls.NAME_TO_PROPERTY.get(key.upper()) @@ -1961,7 +2106,7 @@ class Tuple(Expression): arg_types = {"expressions": False} def isin( - self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy=True, **opts + self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy: bool = True, **opts ) -> In: return In( this=_maybe_copy(self, copy), @@ -1971,7 +2116,7 @@ class Tuple(Expression): class Subqueryable(Unionable): - def subquery(self, alias=None, copy=True) -> Subquery: + def subquery(self, alias: t.Optional[ExpOrStr] = None, copy: bool = True) -> Subquery: """ Convert this expression to an aliased expression that can be used as a Subquery. @@ -1988,12 +2133,14 @@ class Subqueryable(Unionable): Alias: the subquery """ instance = _maybe_copy(self, copy) - return Subquery( - this=instance, - alias=TableAlias(this=to_identifier(alias)) if alias else None, - ) + if not isinstance(alias, Expression): + alias = TableAlias(this=to_identifier(alias)) if alias else None + + return Subquery(this=instance, alias=alias) - def limit(self, expression, dialect=None, copy=True, **opts) -> Select: + def limit( + self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts + ) -> Select: raise NotImplementedError @property @@ -2013,14 +2160,14 @@ class Subqueryable(Unionable): def with_( self, - alias, - as_, - recursive=None, - append=True, - dialect=None, - copy=True, + alias: ExpOrStr, + as_: ExpOrStr, + recursive: t.Optional[bool] = None, + append: bool = True, + dialect: DialectType = None, + copy: bool = True, **opts, - ): + ) -> Subqueryable: """ Append to or set the common table expressions. @@ -2029,43 +2176,22 @@ class Subqueryable(Unionable): 'WITH tbl2 AS (SELECT * FROM tbl) SELECT x FROM tbl2' Args: - alias (str | Expression): the SQL code string to parse as the table name. + alias: the SQL code string to parse as the table name. If an `Expression` instance is passed, this is used as-is. - as_ (str | Expression): the SQL code string to parse as the table expression. + as_: the SQL code string to parse as the table expression. If an `Expression` instance is passed, it will be used as-is. - recursive (bool): set the RECURSIVE part of the expression. Defaults to `False`. - append (bool): if `True`, add to any existing expressions. + recursive: set the RECURSIVE part of the expression. Defaults to `False`. + append: if `True`, add to any existing expressions. Otherwise, this resets the expressions. - dialect (str): the dialect used to parse the input expression. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: - Select: the modified expression. + The modified expression. """ - alias_expression = maybe_parse( - alias, - dialect=dialect, - into=TableAlias, - **opts, - ) - as_expression = maybe_parse( - as_, - dialect=dialect, - **opts, - ) - cte = CTE( - this=as_expression, - alias=alias_expression, - ) - return _apply_child_list_builder( - cte, - instance=self, - arg="with", - append=append, - copy=copy, - into=With, - properties={"recursive": recursive or False}, + return _apply_cte_builder( + self, alias, as_, recursive=recursive, append=append, dialect=dialect, copy=copy, **opts ) @@ -2085,8 +2211,10 @@ QUERY_MODIFIERS = { "order": False, "limit": False, "offset": False, - "lock": False, + "locks": False, "sample": False, + "settings": False, + "format": False, } @@ -2111,6 +2239,15 @@ class Table(Expression): def catalog(self) -> str: return self.text("catalog") + @property + def parts(self) -> t.List[Identifier]: + """Return the parts of a table in order catalog, db, table.""" + return [ + t.cast(Identifier, self.args[part]) + for part in ("catalog", "db", "this") + if self.args.get(part) + ] + # See the TSQL "Querying data in a system-versioned temporal table" page class SystemTime(Expression): @@ -2130,7 +2267,9 @@ class Union(Subqueryable): **QUERY_MODIFIERS, } - def limit(self, expression, dialect=None, copy=True, **opts) -> Select: + def limit( + self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts + ) -> Select: """ Set the LIMIT expression. @@ -2139,16 +2278,16 @@ class Union(Subqueryable): 'SELECT * FROM (SELECT 1 UNION SELECT 1) AS _l_0 LIMIT 1' Args: - expression (str | int | Expression): the SQL code string to parse. + expression: the SQL code string to parse. This can also be an integer. If a `Limit` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `Limit`. - dialect (str): the dialect used to parse the input expression. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: - Select: The limited subqueryable. + The limited subqueryable. """ return ( select("*") @@ -2158,7 +2297,7 @@ class Union(Subqueryable): def select( self, - *expressions: ExpOrStr, + *expressions: t.Optional[ExpOrStr], append: bool = True, dialect: DialectType = None, copy: bool = True, @@ -2255,10 +2394,10 @@ class Schema(Expression): arg_types = {"this": False, "expressions": False} -# Used to represent the FOR UPDATE and FOR SHARE locking read types. -# https://dev.mysql.com/doc/refman/8.0/en/innodb-locking-reads.html +# https://dev.mysql.com/doc/refman/8.0/en/select.html +# https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/SELECT.html class Lock(Expression): - arg_types = {"update": True} + arg_types = {"update": True, "expressions": False, "wait": False} class Select(Subqueryable): @@ -2275,7 +2414,9 @@ class Select(Subqueryable): **QUERY_MODIFIERS, } - def from_(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: + def from_( + self, expression: ExpOrStr, dialect: DialectType = None, copy: bool = True, **opts + ) -> Select: """ Set the FROM expression. @@ -2284,31 +2425,35 @@ class Select(Subqueryable): 'SELECT x FROM tbl' Args: - *expressions (str | Expression): the SQL code strings to parse. + expression : the SQL code strings to parse. If a `From` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `From`. - append (bool): if `True`, add to any existing expressions. - Otherwise, this flattens all the `From` expression into a single expression. - dialect (str): the dialect used to parse the input expression. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: - Select: the modified expression. + The modified Select expression. """ - return _apply_child_list_builder( - *expressions, + return _apply_builder( + expression=expression, instance=self, arg="from", - append=append, - copy=copy, - prefix="FROM", into=From, + prefix="FROM", dialect=dialect, + copy=copy, **opts, ) - def group_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: + def group_by( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: """ Set the GROUP BY expression. @@ -2317,21 +2462,22 @@ class Select(Subqueryable): 'SELECT x, COUNT(1) FROM tbl GROUP BY x' Args: - *expressions (str | Expression): the SQL code strings to parse. + *expressions: the SQL code strings to parse. If a `Group` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `Group`. If nothing is passed in then a group by is not applied to the expression - append (bool): if `True`, add to any existing expressions. + append: if `True`, add to any existing expressions. Otherwise, this flattens all the `Group` expression into a single expression. - dialect (str): the dialect used to parse the input expression. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: - Select: the modified expression. + The modified Select expression. """ if not expressions: return self if not copy else self.copy() + return _apply_child_list_builder( *expressions, instance=self, @@ -2344,7 +2490,14 @@ class Select(Subqueryable): **opts, ) - def order_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: + def order_by( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: """ Set the ORDER BY expression. @@ -2353,17 +2506,17 @@ class Select(Subqueryable): 'SELECT x FROM tbl ORDER BY x DESC' Args: - *expressions (str | Expression): the SQL code strings to parse. + *expressions: the SQL code strings to parse. If a `Group` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `Order`. - append (bool): if `True`, add to any existing expressions. + append: if `True`, add to any existing expressions. Otherwise, this flattens all the `Order` expression into a single expression. - dialect (str): the dialect used to parse the input expression. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: - Select: the modified expression. + The modified Select expression. """ return _apply_child_list_builder( *expressions, @@ -2377,26 +2530,33 @@ class Select(Subqueryable): **opts, ) - def sort_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: + def sort_by( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: """ Set the SORT BY expression. Example: - >>> Select().from_("tbl").select("x").sort_by("x DESC").sql() + >>> Select().from_("tbl").select("x").sort_by("x DESC").sql(dialect="hive") 'SELECT x FROM tbl SORT BY x DESC' Args: - *expressions (str | Expression): the SQL code strings to parse. + *expressions: the SQL code strings to parse. If a `Group` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `SORT`. - append (bool): if `True`, add to any existing expressions. + append: if `True`, add to any existing expressions. Otherwise, this flattens all the `Order` expression into a single expression. - dialect (str): the dialect used to parse the input expression. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: - Select: the modified expression. + The modified Select expression. """ return _apply_child_list_builder( *expressions, @@ -2410,26 +2570,33 @@ class Select(Subqueryable): **opts, ) - def cluster_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: + def cluster_by( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: """ Set the CLUSTER BY expression. Example: - >>> Select().from_("tbl").select("x").cluster_by("x DESC").sql() + >>> Select().from_("tbl").select("x").cluster_by("x DESC").sql(dialect="hive") 'SELECT x FROM tbl CLUSTER BY x DESC' Args: - *expressions (str | Expression): the SQL code strings to parse. + *expressions: the SQL code strings to parse. If a `Group` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `Cluster`. - append (bool): if `True`, add to any existing expressions. + append: if `True`, add to any existing expressions. Otherwise, this flattens all the `Order` expression into a single expression. - dialect (str): the dialect used to parse the input expression. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: - Select: the modified expression. + The modified Select expression. """ return _apply_child_list_builder( *expressions, @@ -2443,7 +2610,9 @@ class Select(Subqueryable): **opts, ) - def limit(self, expression, dialect=None, copy=True, **opts) -> Select: + def limit( + self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts + ) -> Select: """ Set the LIMIT expression. @@ -2452,13 +2621,13 @@ class Select(Subqueryable): 'SELECT x FROM tbl LIMIT 10' Args: - expression (str | int | Expression): the SQL code string to parse. + expression: the SQL code string to parse. This can also be an integer. If a `Limit` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `Limit`. - dialect (str): the dialect used to parse the input expression. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: Select: the modified expression. @@ -2474,7 +2643,9 @@ class Select(Subqueryable): **opts, ) - def offset(self, expression, dialect=None, copy=True, **opts) -> Select: + def offset( + self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts + ) -> Select: """ Set the OFFSET expression. @@ -2483,16 +2654,16 @@ class Select(Subqueryable): 'SELECT x FROM tbl OFFSET 10' Args: - expression (str | int | Expression): the SQL code string to parse. + expression: the SQL code string to parse. This can also be an integer. If a `Offset` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `Offset`. - dialect (str): the dialect used to parse the input expression. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: - Select: the modified expression. + The modified Select expression. """ return _apply_builder( expression=expression, @@ -2507,7 +2678,7 @@ class Select(Subqueryable): def select( self, - *expressions: ExpOrStr, + *expressions: t.Optional[ExpOrStr], append: bool = True, dialect: DialectType = None, copy: bool = True, @@ -2530,7 +2701,7 @@ class Select(Subqueryable): opts: other options to use to parse the input expressions. Returns: - Select: the modified expression. + The modified Select expression. """ return _apply_list_builder( *expressions, @@ -2542,7 +2713,14 @@ class Select(Subqueryable): **opts, ) - def lateral(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: + def lateral( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: """ Append to or set the LATERAL expressions. @@ -2551,16 +2729,16 @@ class Select(Subqueryable): 'SELECT x FROM tbl LATERAL VIEW OUTER EXPLODE(y) tbl2 AS z' Args: - *expressions (str | Expression): the SQL code strings to parse. + *expressions: the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. - append (bool): if `True`, add to any existing expressions. + append: if `True`, add to any existing expressions. Otherwise, this resets the expressions. - dialect (str): the dialect used to parse the input expressions. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: - Select: the modified expression. + The modified Select expression. """ return _apply_list_builder( *expressions, @@ -2576,14 +2754,14 @@ class Select(Subqueryable): def join( self, - expression, - on=None, - using=None, - append=True, - join_type=None, - join_alias=None, - dialect=None, - copy=True, + expression: ExpOrStr, + on: t.Optional[ExpOrStr] = None, + using: t.Optional[ExpOrStr | t.List[ExpOrStr]] = None, + append: bool = True, + join_type: t.Optional[str] = None, + join_alias: t.Optional[Identifier | str] = None, + dialect: DialectType = None, + copy: bool = True, **opts, ) -> Select: """ @@ -2602,18 +2780,19 @@ class Select(Subqueryable): 'SELECT * FROM tbl LEFT OUTER JOIN tbl2 ON tbl1.y = tbl2.y' Args: - expression (str | Expression): the SQL code string to parse. + expression: the SQL code string to parse. If an `Expression` instance is passed, it will be used as-is. - on (str | Expression): optionally specify the join "on" criteria as a SQL string. + on: optionally specify the join "on" criteria as a SQL string. If an `Expression` instance is passed, it will be used as-is. - using (str | Expression): optionally specify the join "using" criteria as a SQL string. + using: optionally specify the join "using" criteria as a SQL string. If an `Expression` instance is passed, it will be used as-is. - append (bool): if `True`, add to any existing expressions. + append: if `True`, add to any existing expressions. Otherwise, this resets the expressions. - join_type (str): If set, alter the parsed join type - dialect (str): the dialect used to parse the input expressions. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + join_type: if set, alter the parsed join type. + join_alias: an optional alias for the joined source. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: Select: the modified expression. @@ -2621,9 +2800,9 @@ class Select(Subqueryable): parse_args = {"dialect": dialect, **opts} try: - expression = maybe_parse(expression, into=Join, prefix="JOIN", **parse_args) + expression = maybe_parse(expression, into=Join, prefix="JOIN", **parse_args) # type: ignore except ParseError: - expression = maybe_parse(expression, into=(Join, Expression), **parse_args) + expression = maybe_parse(expression, into=(Join, Expression), **parse_args) # type: ignore join = expression if isinstance(expression, Join) else Join(this=expression) @@ -2645,12 +2824,12 @@ class Select(Subqueryable): join.set("kind", kind.text) if on: - on = and_(*ensure_collection(on), dialect=dialect, copy=copy, **opts) + on = and_(*ensure_list(on), dialect=dialect, copy=copy, **opts) join.set("on", on) if using: join = _apply_list_builder( - *ensure_collection(using), + *ensure_list(using), instance=join, arg="using", append=append, @@ -2660,6 +2839,7 @@ class Select(Subqueryable): if join_alias: join.set("this", alias_(join.this, join_alias, table=True)) + return _apply_list_builder( join, instance=self, @@ -2669,7 +2849,14 @@ class Select(Subqueryable): **opts, ) - def where(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: + def where( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: """ Append to or set the WHERE expressions. @@ -2678,14 +2865,14 @@ class Select(Subqueryable): "SELECT x FROM tbl WHERE x = 'a' OR x < 'b'" Args: - *expressions (str | Expression): the SQL code strings to parse. + *expressions: the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. Multiple expressions are combined with an AND operator. - append (bool): if `True`, AND the new expressions to any existing expression. + append: if `True`, AND the new expressions to any existing expression. Otherwise, this resets the expression. - dialect (str): the dialect used to parse the input expressions. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: Select: the modified expression. @@ -2701,7 +2888,14 @@ class Select(Subqueryable): **opts, ) - def having(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: + def having( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: """ Append to or set the HAVING expressions. @@ -2710,17 +2904,17 @@ class Select(Subqueryable): 'SELECT x, COUNT(y) FROM tbl GROUP BY x HAVING COUNT(y) > 3' Args: - *expressions (str | Expression): the SQL code strings to parse. + *expressions: the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. Multiple expressions are combined with an AND operator. - append (bool): if `True`, AND the new expressions to any existing expression. + append: if `True`, AND the new expressions to any existing expression. Otherwise, this resets the expression. - dialect (str): the dialect used to parse the input expressions. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: - Select: the modified expression. + The modified Select expression. """ return _apply_conjunction_builder( *expressions, @@ -2733,7 +2927,14 @@ class Select(Subqueryable): **opts, ) - def window(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: + def window( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: return _apply_list_builder( *expressions, instance=self, @@ -2745,7 +2946,14 @@ class Select(Subqueryable): **opts, ) - def qualify(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: + def qualify( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: return _apply_conjunction_builder( *expressions, instance=self, @@ -2757,7 +2965,9 @@ class Select(Subqueryable): **opts, ) - def distinct(self, *ons: ExpOrStr, distinct: bool = True, copy: bool = True) -> Select: + def distinct( + self, *ons: t.Optional[ExpOrStr], distinct: bool = True, copy: bool = True + ) -> Select: """ Set the OFFSET expression. @@ -2774,11 +2984,18 @@ class Select(Subqueryable): Select: the modified expression. """ instance = _maybe_copy(self, copy) - on = Tuple(expressions=[maybe_parse(on, copy=copy) for on in ons]) if ons else None + on = Tuple(expressions=[maybe_parse(on, copy=copy) for on in ons if on]) if ons else None instance.set("distinct", Distinct(on=on) if distinct else None) return instance - def ctas(self, table, properties=None, dialect=None, copy=True, **opts) -> Create: + def ctas( + self, + table: ExpOrStr, + properties: t.Optional[t.Dict] = None, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Create: """ Convert this expression to a CREATE TABLE AS statement. @@ -2787,15 +3004,15 @@ class Select(Subqueryable): 'CREATE TABLE x AS SELECT * FROM tbl' Args: - table (str | Expression): the SQL code string to parse as the table name. + table: the SQL code string to parse as the table name. If another `Expression` instance is passed, it will be used as-is. - properties (dict): an optional mapping of table properties - dialect (str): the dialect used to parse the input table. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input table. + properties: an optional mapping of table properties + dialect: the dialect used to parse the input table. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input table. Returns: - Create: the CREATE TABLE AS expression + The new Create expression. """ instance = _maybe_copy(self, copy) table_expression = maybe_parse( @@ -2835,7 +3052,7 @@ class Select(Subqueryable): """ inst = _maybe_copy(self, copy) - inst.set("lock", Lock(update=update)) + inst.set("locks", [Lock(update=update)]) return inst @@ -2874,7 +3091,7 @@ class Subquery(DerivedTable, Unionable): return self.this.is_star @property - def output_name(self): + def output_name(self) -> str: return self.alias @@ -2903,13 +3120,17 @@ class Tag(Expression): } +# Represents both the standard SQL PIVOT operator and DuckDB's "simplified" PIVOT syntax +# https://duckdb.org/docs/sql/statements/pivot class Pivot(Expression): arg_types = { "this": False, "alias": False, "expressions": True, - "field": True, - "unpivot": True, + "field": False, + "unpivot": False, + "using": False, + "group": False, "columns": False, } @@ -2948,7 +3169,7 @@ class Star(Expression): return "*" @property - def output_name(self): + def output_name(self) -> str: return self.name @@ -2961,7 +3182,7 @@ class SessionParameter(Expression): class Placeholder(Expression): - arg_types = {"this": False} + arg_types = {"this": False, "kind": False} class Null(Condition): @@ -2976,6 +3197,10 @@ class Boolean(Condition): pass +class DataTypeSize(Expression): + arg_types = {"this": True, "expression": False} + + class DataType(Expression): arg_types = { "this": True, @@ -2986,68 +3211,69 @@ class DataType(Expression): } class Type(AutoName): - CHAR = auto() - NCHAR = auto() - VARCHAR = auto() - NVARCHAR = auto() - TEXT = auto() - MEDIUMTEXT = auto() - LONGTEXT = auto() - MEDIUMBLOB = auto() - LONGBLOB = auto() - BINARY = auto() - VARBINARY = auto() - INT = auto() - UINT = auto() - TINYINT = auto() - UTINYINT = auto() - SMALLINT = auto() - USMALLINT = auto() - BIGINT = auto() - UBIGINT = auto() - INT128 = auto() - UINT128 = auto() - INT256 = auto() - UINT256 = auto() - FLOAT = auto() - DOUBLE = auto() - DECIMAL = auto() + ARRAY = auto() BIGDECIMAL = auto() + BIGINT = auto() + BIGSERIAL = auto() + BINARY = auto() BIT = auto() BOOLEAN = auto() - JSON = auto() - JSONB = auto() - INTERVAL = auto() - TIME = auto() - TIMESTAMP = auto() - TIMESTAMPTZ = auto() - TIMESTAMPLTZ = auto() + CHAR = auto() DATE = auto() DATETIME = auto() - ARRAY = auto() - MAP = auto() - UUID = auto() + DATETIME64 = auto() + DECIMAL = auto() + DOUBLE = auto() + FLOAT = auto() GEOGRAPHY = auto() GEOMETRY = auto() - STRUCT = auto() - NULLABLE = auto() HLLSKETCH = auto() HSTORE = auto() - SUPER = auto() - SERIAL = auto() - SMALLSERIAL = auto() - BIGSERIAL = auto() - XML = auto() - UNIQUEIDENTIFIER = auto() - MONEY = auto() - SMALLMONEY = auto() - ROWVERSION = auto() IMAGE = auto() - VARIANT = auto() - OBJECT = auto() INET = auto() + INT = auto() + INT128 = auto() + INT256 = auto() + INTERVAL = auto() + JSON = auto() + JSONB = auto() + LONGBLOB = auto() + LONGTEXT = auto() + MAP = auto() + MEDIUMBLOB = auto() + MEDIUMTEXT = auto() + MONEY = auto() + NCHAR = auto() NULL = auto() + NULLABLE = auto() + NVARCHAR = auto() + OBJECT = auto() + ROWVERSION = auto() + SERIAL = auto() + SMALLINT = auto() + SMALLMONEY = auto() + SMALLSERIAL = auto() + STRUCT = auto() + SUPER = auto() + TEXT = auto() + TIME = auto() + TIMESTAMP = auto() + TIMESTAMPTZ = auto() + TIMESTAMPLTZ = auto() + TINYINT = auto() + UBIGINT = auto() + UINT = auto() + USMALLINT = auto() + UTINYINT = auto() UNKNOWN = auto() # Sentinel value, useful for type annotation + UINT128 = auto() + UINT256 = auto() + UNIQUEIDENTIFIER = auto() + UUID = auto() + VARBINARY = auto() + VARCHAR = auto() + VARIANT = auto() + XML = auto() TEXT_TYPES = { Type.CHAR, @@ -3079,6 +3305,7 @@ class DataType(Expression): Type.TIMESTAMPLTZ, Type.DATE, Type.DATETIME, + Type.DATETIME64, } @classmethod @@ -3092,6 +3319,7 @@ class DataType(Expression): data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[dtype.upper()]) else: data_type_exp = parse_one(dtype, read=dialect, into=DataType) + if data_type_exp is None: raise ValueError(f"Unparsable data type value: {dtype}") elif isinstance(dtype, DataType.Type): @@ -3100,6 +3328,7 @@ class DataType(Expression): return dtype else: raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type") + return DataType(**{**data_type_exp.args, **kwargs}) def is_type(self, dtype: DataType.Type) -> bool: @@ -3361,7 +3590,7 @@ class Alias(Expression): arg_types = {"this": True, "alias": False} @property - def output_name(self): + def output_name(self) -> str: return self.alias @@ -3411,12 +3640,17 @@ class TimeUnit(Expression): args["unit"] = Var(this=unit.name) elif isinstance(unit, Week): unit.set("this", Var(this=unit.this.name)) + super().__init__(**args) class Interval(TimeUnit): arg_types = {"this": False, "unit": False} + @property + def unit(self) -> t.Optional[Var]: + return self.args.get("unit") + class IgnoreNulls(Expression): pass @@ -3480,6 +3714,10 @@ class AggFunc(Func): pass +class ParameterizedAgg(AggFunc): + arg_types = {"this": True, "expressions": True, "params": True} + + class Abs(Func): pass @@ -3498,6 +3736,7 @@ class Hll(AggFunc): class ApproxDistinct(AggFunc): arg_types = {"this": True, "accuracy": False} + _sql_names = ["APPROX_DISTINCT", "APPROX_COUNT_DISTINCT"] class Array(Func): @@ -3600,17 +3839,21 @@ class Cast(Func): return self.this.name @property - def to(self): + def to(self) -> DataType: return self.args["to"] @property - def output_name(self): + def output_name(self) -> str: return self.name def is_type(self, dtype: DataType.Type) -> bool: return self.to.is_type(dtype) +class CastToStrType(Func): + arg_types = {"this": True, "expression": True} + + class Collate(Binary): pass @@ -3796,10 +4039,6 @@ class Explode(Func): pass -class ExponentialTimeDecayedAvg(AggFunc): - arg_types = {"this": True, "time": False, "decay": False} - - class Floor(Func): arg_types = {"this": True, "decimals": False} @@ -3821,18 +4060,10 @@ class GroupConcat(Func): arg_types = {"this": True, "separator": False} -class GroupUniqArray(AggFunc): - arg_types = {"this": True, "size": False} - - class Hex(Func): pass -class Histogram(AggFunc): - arg_types = {"this": True, "bins": False} - - class If(Func): arg_types = {"this": True, "true": True, "false": False} @@ -3843,7 +4074,7 @@ class IfNull(Func): class Initcap(Func): - pass + arg_types = {"this": True, "expression": False} class JSONKeyValue(Expression): @@ -3861,6 +4092,14 @@ class JSONObject(Func): } +class OpenJSONColumnDef(Expression): + arg_types = {"this": True, "kind": True, "path": False, "as_json": False} + + +class OpenJSON(Func): + arg_types = {"this": True, "path": False, "expressions": False} + + class JSONBContains(Binary): _sql_names = ["JSONB_CONTAINS"] @@ -3945,6 +4184,14 @@ class VarMap(Func): arg_types = {"keys": True, "values": True} is_var_len_args = True + @property + def keys(self) -> t.List[Expression]: + return self.args["keys"].expressions + + @property + def values(self) -> t.List[Expression]: + return self.args["values"].expressions + # https://dev.mysql.com/doc/refman/8.0/en/fulltext-search.html class MatchAgainst(Func): @@ -3993,17 +4240,6 @@ class Quantile(AggFunc): arg_types = {"this": True, "quantile": True} -# Clickhouse-specific: -# https://clickhouse.com/docs/en/sql-reference/aggregate-functions/reference/quantiles/#quantiles -class Quantiles(AggFunc): - arg_types = {"parameters": True, "expressions": True} - is_var_len_args = True - - -class QuantileIf(AggFunc): - arg_types = {"parameters": True, "expressions": True} - - class ApproxQuantile(Quantile): arg_types = {"this": True, "quantile": True, "accuracy": False, "weight": False} @@ -4089,6 +4325,10 @@ class Substring(Func): arg_types = {"this": True, "start": False, "length": False} +class StandardHash(Func): + arg_types = {"this": True, "expression": False} + + class StrPosition(Func): arg_types = { "this": True, @@ -4328,15 +4568,19 @@ def maybe_parse( return sql_or_expression.copy() return sql_or_expression + if sql_or_expression is None: + raise ParseError(f"SQL cannot be None") + import sqlglot sql = str(sql_or_expression) if prefix: sql = f"{prefix} {sql}" + return sqlglot.parse_one(sql, read=dialect, into=into, **opts) -def _maybe_copy(instance, copy=True): +def _maybe_copy(instance: E, copy: bool = True) -> E: return instance.copy() if copy else instance @@ -4383,16 +4627,18 @@ def _apply_child_list_builder( instance = _maybe_copy(instance, copy) parsed = [] for expression in expressions: - if _is_wrong_expression(expression, into): - expression = into(expressions=[expression]) - expression = maybe_parse( - expression, - into=into, - dialect=dialect, - prefix=prefix, - **opts, - ) - parsed.extend(expression.expressions) + if expression is not None: + if _is_wrong_expression(expression, into): + expression = into(expressions=[expression]) + + expression = maybe_parse( + expression, + into=into, + dialect=dialect, + prefix=prefix, + **opts, + ) + parsed.extend(expression.expressions) existing = instance.args.get(arg) if append and existing: @@ -4402,6 +4648,7 @@ def _apply_child_list_builder( for k, v in (properties or {}).items(): child.set(k, v) instance.set(arg, child) + return instance @@ -4427,6 +4674,7 @@ def _apply_list_builder( **opts, ) for expression in expressions + if expression is not None ] existing_expressions = inst.args.get(arg) @@ -4463,25 +4711,59 @@ def _apply_conjunction_builder( return inst -def _combine(expressions, operator, dialect=None, copy=True, **opts): - expressions = [ - condition(expression, dialect=dialect, copy=copy, **opts) for expression in expressions +def _apply_cte_builder( + instance: E, + alias: ExpOrStr, + as_: ExpOrStr, + recursive: t.Optional[bool] = None, + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, +) -> E: + alias_expression = maybe_parse(alias, dialect=dialect, into=TableAlias, **opts) + as_expression = maybe_parse(as_, dialect=dialect, **opts) + cte = CTE(this=as_expression, alias=alias_expression) + return _apply_child_list_builder( + cte, + instance=instance, + arg="with", + append=append, + copy=copy, + into=With, + properties={"recursive": recursive or False}, + ) + + +def _combine( + expressions: t.Sequence[t.Optional[ExpOrStr]], + operator: t.Type[Connector], + dialect: DialectType = None, + copy: bool = True, + **opts, +) -> Expression: + conditions = [ + condition(expression, dialect=dialect, copy=copy, **opts) + for expression in expressions + if expression is not None ] - this = expressions[0] - if expressions[1:]: + + this, *rest = conditions + if rest: this = _wrap(this, Connector) - for expression in expressions[1:]: + for expression in rest: this = operator(this=this, expression=_wrap(expression, Connector)) + return this def _wrap(expression: E, kind: t.Type[Expression]) -> E | Paren: - if isinstance(expression, kind): - return Paren(this=expression) - return expression + return Paren(this=expression) if isinstance(expression, kind) else expression -def union(left, right, distinct=True, dialect=None, **opts): +def union( + left: ExpOrStr, right: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts +) -> Union: """ Initializes a syntax tree from one UNION expression. @@ -4490,15 +4772,16 @@ def union(left, right, distinct=True, dialect=None, **opts): 'SELECT * FROM foo UNION SELECT * FROM bla' Args: - left (str | Expression): the SQL code string corresponding to the left-hand side. + left: the SQL code string corresponding to the left-hand side. If an `Expression` instance is passed, it will be used as-is. - right (str | Expression): the SQL code string corresponding to the right-hand side. + right: the SQL code string corresponding to the right-hand side. If an `Expression` instance is passed, it will be used as-is. - distinct (bool): set the DISTINCT flag if and only if this is true. - dialect (str): the dialect used to parse the input expression. - opts (kwargs): other options to use to parse the input expressions. + distinct: set the DISTINCT flag if and only if this is true. + dialect: the dialect used to parse the input expression. + opts: other options to use to parse the input expressions. + Returns: - Union: the syntax tree for the UNION expression. + The new Union instance. """ left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts) right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts) @@ -4506,7 +4789,9 @@ def union(left, right, distinct=True, dialect=None, **opts): return Union(this=left, expression=right, distinct=distinct) -def intersect(left, right, distinct=True, dialect=None, **opts): +def intersect( + left: ExpOrStr, right: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts +) -> Intersect: """ Initializes a syntax tree from one INTERSECT expression. @@ -4515,15 +4800,16 @@ def intersect(left, right, distinct=True, dialect=None, **opts): 'SELECT * FROM foo INTERSECT SELECT * FROM bla' Args: - left (str | Expression): the SQL code string corresponding to the left-hand side. + left: the SQL code string corresponding to the left-hand side. If an `Expression` instance is passed, it will be used as-is. - right (str | Expression): the SQL code string corresponding to the right-hand side. + right: the SQL code string corresponding to the right-hand side. If an `Expression` instance is passed, it will be used as-is. - distinct (bool): set the DISTINCT flag if and only if this is true. - dialect (str): the dialect used to parse the input expression. - opts (kwargs): other options to use to parse the input expressions. + distinct: set the DISTINCT flag if and only if this is true. + dialect: the dialect used to parse the input expression. + opts: other options to use to parse the input expressions. + Returns: - Intersect: the syntax tree for the INTERSECT expression. + The new Intersect instance. """ left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts) right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts) @@ -4531,7 +4817,9 @@ def intersect(left, right, distinct=True, dialect=None, **opts): return Intersect(this=left, expression=right, distinct=distinct) -def except_(left, right, distinct=True, dialect=None, **opts): +def except_( + left: ExpOrStr, right: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts +) -> Except: """ Initializes a syntax tree from one EXCEPT expression. @@ -4540,15 +4828,16 @@ def except_(left, right, distinct=True, dialect=None, **opts): 'SELECT * FROM foo EXCEPT SELECT * FROM bla' Args: - left (str | Expression): the SQL code string corresponding to the left-hand side. + left: the SQL code string corresponding to the left-hand side. If an `Expression` instance is passed, it will be used as-is. - right (str | Expression): the SQL code string corresponding to the right-hand side. + right: the SQL code string corresponding to the right-hand side. If an `Expression` instance is passed, it will be used as-is. - distinct (bool): set the DISTINCT flag if and only if this is true. - dialect (str): the dialect used to parse the input expression. - opts (kwargs): other options to use to parse the input expressions. + distinct: set the DISTINCT flag if and only if this is true. + dialect: the dialect used to parse the input expression. + opts: other options to use to parse the input expressions. + Returns: - Except: the syntax tree for the EXCEPT statement. + The new Except instance. """ left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts) right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts) @@ -4578,7 +4867,7 @@ def select(*expressions: ExpOrStr, dialect: DialectType = None, **opts) -> Selec return Select().select(*expressions, dialect=dialect, **opts) -def from_(*expressions, dialect=None, **opts) -> Select: +def from_(expression: ExpOrStr, dialect: DialectType = None, **opts) -> Select: """ Initializes a syntax tree from a FROM expression. @@ -4587,9 +4876,9 @@ def from_(*expressions, dialect=None, **opts) -> Select: 'SELECT col1, col2 FROM tbl' Args: - *expressions (str | Expression): the SQL code string to parse as the FROM expressions of a + *expression: the SQL code string to parse as the FROM expressions of a SELECT statement. If an Expression instance is passed, this is used as-is. - dialect (str): the dialect used to parse the input expression (in the case that the + dialect: the dialect used to parse the input expression (in the case that the input expression is a SQL string). **opts: other options to use to parse the input expressions (again, in the case that the input expression is a SQL string). @@ -4597,7 +4886,7 @@ def from_(*expressions, dialect=None, **opts) -> Select: Returns: Select: the syntax tree for the SELECT statement. """ - return Select().from_(*expressions, dialect=dialect, **opts) + return Select().from_(expression, dialect=dialect, **opts) def update( @@ -4680,7 +4969,54 @@ def delete( return delete_expr -def condition(expression, dialect=None, copy=True, **opts) -> Condition: +def insert( + expression: ExpOrStr, + into: ExpOrStr, + columns: t.Optional[t.Sequence[ExpOrStr]] = None, + overwrite: t.Optional[bool] = None, + dialect: DialectType = None, + copy: bool = True, + **opts, +) -> Insert: + """ + Builds an INSERT statement. + + Example: + >>> insert("VALUES (1, 2, 3)", "tbl").sql() + 'INSERT INTO tbl VALUES (1, 2, 3)' + + Args: + expression: the sql string or expression of the INSERT statement + into: the tbl to insert data to. + columns: optionally the table's column names. + overwrite: whether to INSERT OVERWRITE or not. + dialect: the dialect used to parse the input expressions. + copy: whether or not to copy the expression. + **opts: other options to use to parse the input expressions. + + Returns: + Insert: the syntax tree for the INSERT statement. + """ + expr = maybe_parse(expression, dialect=dialect, copy=copy, **opts) + this: Table | Schema = maybe_parse(into, into=Table, dialect=dialect, copy=copy, **opts) + + if columns: + this = _apply_list_builder( + *columns, + instance=Schema(this=this), + arg="expressions", + into=Identifier, + copy=False, + dialect=dialect, + **opts, + ) + + return Insert(this=this, expression=expr, overwrite=overwrite) + + +def condition( + expression: ExpOrStr, dialect: DialectType = None, copy: bool = True, **opts +) -> Condition: """ Initialize a logical condition expression. @@ -4695,18 +5031,18 @@ def condition(expression, dialect=None, copy=True, **opts) -> Condition: 'SELECT * FROM tbl WHERE x = 1 AND y = 1' Args: - *expression (str | Expression): the SQL code string to parse. + *expression: the SQL code string to parse. If an Expression instance is passed, this is used as-is. - dialect (str): the dialect used to parse the input expression (in the case that the + dialect: the dialect used to parse the input expression (in the case that the input expression is a SQL string). - copy (bool): Whether or not to copy `expression` (only applies to expressions). + copy: Whether or not to copy `expression` (only applies to expressions). **opts: other options to use to parse the input expressions (again, in the case that the input expression is a SQL string). Returns: - Condition: the expression + The new Condition instance """ - return maybe_parse( # type: ignore + return maybe_parse( expression, into=Condition, dialect=dialect, @@ -4715,7 +5051,9 @@ def condition(expression, dialect=None, copy=True, **opts) -> Condition: ) -def and_(*expressions, dialect=None, copy=True, **opts) -> And: +def and_( + *expressions: t.Optional[ExpOrStr], dialect: DialectType = None, copy: bool = True, **opts +) -> Condition: """ Combine multiple conditions with an AND logical operator. @@ -4724,19 +5062,21 @@ def and_(*expressions, dialect=None, copy=True, **opts) -> And: 'x = 1 AND (y = 1 AND z = 1)' Args: - *expressions (str | Expression): the SQL code strings to parse. + *expressions: the SQL code strings to parse. If an Expression instance is passed, this is used as-is. - dialect (str): the dialect used to parse the input expression. - copy (bool): whether or not to copy `expressions` (only applies to Expressions). + dialect: the dialect used to parse the input expression. + copy: whether or not to copy `expressions` (only applies to Expressions). **opts: other options to use to parse the input expressions. Returns: And: the new condition """ - return _combine(expressions, And, dialect, copy=copy, **opts) + return t.cast(Condition, _combine(expressions, And, dialect, copy=copy, **opts)) -def or_(*expressions, dialect=None, copy=True, **opts) -> Or: +def or_( + *expressions: t.Optional[ExpOrStr], dialect: DialectType = None, copy: bool = True, **opts +) -> Condition: """ Combine multiple conditions with an OR logical operator. @@ -4745,19 +5085,19 @@ def or_(*expressions, dialect=None, copy=True, **opts) -> Or: 'x = 1 OR (y = 1 OR z = 1)' Args: - *expressions (str | Expression): the SQL code strings to parse. + *expressions: the SQL code strings to parse. If an Expression instance is passed, this is used as-is. - dialect (str): the dialect used to parse the input expression. - copy (bool): whether or not to copy `expressions` (only applies to Expressions). + dialect: the dialect used to parse the input expression. + copy: whether or not to copy `expressions` (only applies to Expressions). **opts: other options to use to parse the input expressions. Returns: Or: the new condition """ - return _combine(expressions, Or, dialect, copy=copy, **opts) + return t.cast(Condition, _combine(expressions, Or, dialect, copy=copy, **opts)) -def not_(expression, dialect=None, copy=True, **opts) -> Not: +def not_(expression: ExpOrStr, dialect: DialectType = None, copy: bool = True, **opts) -> Not: """ Wrap a condition with a NOT operator. @@ -4766,13 +5106,14 @@ def not_(expression, dialect=None, copy=True, **opts) -> Not: "NOT this_suit = 'black'" Args: - expression (str | Expression): the SQL code strings to parse. + expression: the SQL code string to parse. If an Expression instance is passed, this is used as-is. - dialect (str): the dialect used to parse the input expression. + dialect: the dialect used to parse the input expression. + copy: whether to copy the expression or not. **opts: other options to use to parse the input expressions. Returns: - Not: the new condition + The new condition. """ this = condition( expression, @@ -4783,29 +5124,47 @@ def not_(expression, dialect=None, copy=True, **opts) -> Not: return Not(this=_wrap(this, Connector)) -def paren(expression, copy=True) -> Paren: - return Paren(this=_maybe_copy(expression, copy)) +def paren(expression: ExpOrStr, copy: bool = True) -> Paren: + """ + Wrap an expression in parentheses. + + Example: + >>> paren("5 + 3").sql() + '(5 + 3)' + + Args: + expression: the SQL code string to parse. + If an Expression instance is passed, this is used as-is. + copy: whether to copy the expression or not. + + Returns: + The wrapped expression. + """ + return Paren(this=maybe_parse(expression, copy=copy)) SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$") @t.overload -def to_identifier(name: None, quoted: t.Optional[bool] = None) -> None: +def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None: ... @t.overload -def to_identifier(name: str | Identifier, quoted: t.Optional[bool] = None) -> Identifier: +def to_identifier( + name: str | Identifier, quoted: t.Optional[bool] = None, copy: bool = True +) -> Identifier: ... -def to_identifier(name, quoted=None): +def to_identifier(name, quoted=None, copy=True): """Builds an identifier. Args: name: The name to turn into an identifier. quoted: Whether or not force quote the identifier. + copy: Whether or not to copy a passed in Identefier node. Returns: The identifier ast node. @@ -4815,7 +5174,7 @@ def to_identifier(name, quoted=None): return None if isinstance(name, Identifier): - identifier = name + identifier = _maybe_copy(name, copy) elif isinstance(name, str): identifier = Identifier( this=name, @@ -4858,13 +5217,17 @@ def to_table(sql_path: None, **kwargs) -> None: ... -def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]: +def to_table( + sql_path: t.Optional[str | Table], dialect: DialectType = None, **kwargs +) -> t.Optional[Table]: """ Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional. If a table is passed in then that table is returned. Args: sql_path: a `[catalog].[schema].[table]` string. + dialect: the source dialect according to which the table name will be parsed. + kwargs: the kwargs to instantiate the resulting `Table` expression with. Returns: A table expression. @@ -4874,8 +5237,12 @@ def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]: if not isinstance(sql_path, str): raise ValueError(f"Invalid type provided for a table: {type(sql_path)}") - catalog, db, table_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 3)) - return Table(this=table_name, db=db, catalog=catalog, **kwargs) + table = maybe_parse(sql_path, into=Table, dialect=dialect) + if table: + for k, v in kwargs.items(): + table.set(k, v) + + return table def to_column(sql_path: str | Column, **kwargs) -> Column: @@ -4902,6 +5269,7 @@ def alias_( table: bool | t.Sequence[str | Identifier] = False, quoted: t.Optional[bool] = None, dialect: DialectType = None, + copy: bool = True, **opts, ): """Create an Alias expression. @@ -4921,18 +5289,17 @@ def alias_( table: Whether or not to create a table alias, can also be a list of columns. quoted: whether or not to quote the alias dialect: the dialect used to parse the input expression. + copy: Whether or not to copy the expression. **opts: other options to use to parse the input expressions. Returns: Alias: the aliased expression """ - exp = maybe_parse(expression, dialect=dialect, **opts) + exp = maybe_parse(expression, dialect=dialect, copy=copy, **opts) alias = to_identifier(alias, quoted=quoted) if table: table_alias = TableAlias(this=alias) - - exp = exp.copy() if isinstance(expression, Expression) else exp exp.set("alias", table_alias) if not isinstance(table, bool): @@ -4948,13 +5315,17 @@ def alias_( # [1]: https://cloud.google.com/bigquery/docs/reference/standard-sql/window-function-calls if "alias" in exp.arg_types and not isinstance(exp, Window): - exp = exp.copy() exp.set("alias", alias) return exp return Alias(this=exp, alias=alias) -def subquery(expression, alias=None, dialect=None, **opts): +def subquery( + expression: ExpOrStr, + alias: t.Optional[Identifier | str] = None, + dialect: DialectType = None, + **opts, +) -> Select: """ Build a subquery expression. @@ -4963,14 +5334,14 @@ def subquery(expression, alias=None, dialect=None, **opts): 'SELECT x FROM (SELECT x FROM tbl) AS bar' Args: - expression (str | Expression): the SQL code strings to parse. + expression: the SQL code strings to parse. If an Expression instance is passed, this is used as-is. - alias (str | Expression): the alias name to use. - dialect (str): the dialect used to parse the input expression. + alias: the alias name to use. + dialect: the dialect used to parse the input expression. **opts: other options to use to parse the input expressions. Returns: - Select: a new select with the subquery expression included + A new Select instance with the subquery expression included. """ expression = maybe_parse(expression, dialect=dialect, **opts).subquery(alias) @@ -4988,13 +5359,14 @@ def column( Build a Column. Args: - col: column name - table: table name - db: db name - catalog: catalog name - quoted: whether or not to force quote each part + col: Column name. + table: Table name. + db: Database name. + catalog: Catalog name. + quoted: Whether to force quotes on the column's identifiers. + Returns: - Column: column instance + The new Column instance. """ return Column( this=to_identifier(col, quoted=quoted), @@ -5016,22 +5388,30 @@ def cast(expression: ExpOrStr, to: str | DataType | DataType.Type, **opts) -> Ca to: The datatype to cast to. Returns: - A cast node. + The new Cast instance. """ expression = maybe_parse(expression, **opts) return Cast(this=expression, to=DataType.build(to, **opts)) -def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table: +def table_( + table: Identifier | str, + db: t.Optional[Identifier | str] = None, + catalog: t.Optional[Identifier | str] = None, + quoted: t.Optional[bool] = None, + alias: t.Optional[Identifier | str] = None, +) -> Table: """Build a Table. Args: - table (str | Expression): column name - db (str | Expression): db name - catalog (str | Expression): catalog name + table: Table name. + db: Database name. + catalog: Catalog name. + quote: Whether to force quotes on the table's identifiers. + alias: Table's alias. Returns: - Table: table instance + The new Table instance. """ return Table( this=to_identifier(table, quoted=quoted), @@ -5160,7 +5540,7 @@ def convert(value: t.Any, copy: bool = False) -> Expression: raise ValueError(f"Cannot convert {value}") -def replace_children(expression, fun, *args, **kwargs): +def replace_children(expression: Expression, fun: t.Callable, *args, **kwargs) -> None: """ Replace children of an expression with the result of a lambda fun(child) -> exp. """ @@ -5182,7 +5562,7 @@ def replace_children(expression, fun, *args, **kwargs): expression.args[k] = new_child_nodes if is_list_arg else seq_get(new_child_nodes, 0) -def column_table_names(expression): +def column_table_names(expression: Expression) -> t.List[str]: """ Return all table names referenced through columns in an expression. @@ -5192,19 +5572,19 @@ def column_table_names(expression): ['c', 'a'] Args: - expression (sqlglot.Expression): expression to find table names + expression: expression to find table names. Returns: - list: A list of unique names + A list of unique names. """ return list(dict.fromkeys(column.table for column in expression.find_all(Column))) -def table_name(table) -> str: +def table_name(table: Table | str) -> str: """Get the full name of a table as a string. Args: - table (exp.Table | str): table expression node or string. + table: table expression node or string. Examples: >>> from sqlglot import exp, parse_one @@ -5220,23 +5600,15 @@ def table_name(table) -> str: if not table: raise ValueError(f"Cannot parse {table}") - return ".".join( - part - for part in ( - table.text("catalog"), - table.text("db"), - table.name, - ) - if part - ) + return ".".join(part for part in (table.text("catalog"), table.text("db"), table.name) if part) -def replace_tables(expression, mapping): +def replace_tables(expression: E, mapping: t.Dict[str, str]) -> E: """Replace all tables in expression according to the mapping. Args: - expression (sqlglot.Expression): expression node to be transformed and replaced. - mapping (Dict[str, str]): mapping of table names. + expression: expression node to be transformed and replaced. + mapping: mapping of table names. Examples: >>> from sqlglot import exp, parse_one @@ -5247,7 +5619,7 @@ def replace_tables(expression, mapping): The mapped expression. """ - def _replace_tables(node): + def _replace_tables(node: Expression) -> Expression: if isinstance(node, Table): new_name = mapping.get(table_name(node)) if new_name: @@ -5260,11 +5632,11 @@ def replace_tables(expression, mapping): return expression.transform(_replace_tables) -def replace_placeholders(expression, *args, **kwargs): +def replace_placeholders(expression: Expression, *args, **kwargs) -> Expression: """Replace placeholders in an expression. Args: - expression (sqlglot.Expression): expression node to be transformed and replaced. + expression: expression node to be transformed and replaced. args: positional names that will substitute unnamed placeholders in the given order. kwargs: keyword arguments that will substitute named placeholders. @@ -5280,7 +5652,7 @@ def replace_placeholders(expression, *args, **kwargs): The mapped expression. """ - def _replace_placeholders(node, args, **kwargs): + def _replace_placeholders(node: Expression, args, **kwargs) -> Expression: if isinstance(node, Placeholder): if node.name: new_name = kwargs.get(node.name) @@ -5378,21 +5750,21 @@ def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func: return function -def true(): +def true() -> Boolean: """ Returns a true Boolean expression. """ return Boolean(this=True) -def false(): +def false() -> Boolean: """ Returns a false Boolean expression. """ return Boolean(this=False) -def null(): +def null() -> Null: """ Returns a Null expression. """ |