From 41f1f5740d2140bfd3b2a282ca1087a4b576679a Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Mon, 15 Apr 2024 07:02:18 +0200 Subject: Merging upstream version 23.10.0. Signed-off-by: Daniel Baumann --- sqlglot/expressions.py | 354 +++++++++++++++++++++++++++---------------------- 1 file changed, 195 insertions(+), 159 deletions(-) (limited to 'sqlglot/expressions.py') diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index e79c04b..5adbb1e 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -58,6 +58,7 @@ class _Expression(type): SQLGLOT_META = "sqlglot.meta" TABLE_PARTS = ("this", "db", "catalog") +COLUMN_PARTS = ("this", "table", "db", "catalog") class Expression(metaclass=_Expression): @@ -175,6 +176,15 @@ class Expression(metaclass=_Expression): """ return isinstance(self, Literal) and not self.args["is_string"] + @property + def is_negative(self) -> bool: + """ + Checks whether an expression is negative. + + Handles both exp.Neg and Literal numbers with "-" which come from optimizer.simplify. + """ + return isinstance(self, Neg) or (self.is_number and self.this.startswith("-")) + @property def is_int(self) -> bool: """ @@ -845,10 +855,14 @@ class Expression(metaclass=_Expression): copy: bool = True, **opts, ) -> In: + subquery = maybe_parse(query, copy=copy, **opts) if query else None + if subquery and not isinstance(subquery, Subquery): + subquery = subquery.subquery(copy=False) + return In( this=maybe_copy(self, copy), expressions=[convert(e, copy=copy) for e in expressions], - query=maybe_parse(query, copy=copy, **opts) if query else None, + query=subquery, unnest=( Unnest( expressions=[ @@ -1018,14 +1032,14 @@ class Query(Expression): return Subquery(this=instance, alias=alias) def limit( - self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts - ) -> Select: + self: Q, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts + ) -> Q: """ Adds a LIMIT clause to this query. Example: >>> select("1").union(select("1")).limit(1).sql() - 'SELECT * FROM (SELECT 1 UNION SELECT 1) AS _l_0 LIMIT 1' + 'SELECT 1 UNION SELECT 1 LIMIT 1' Args: expression: the SQL code string to parse. @@ -1039,10 +1053,90 @@ class Query(Expression): Returns: A limited Select expression. """ - return ( - select("*") - .from_(self.subquery(alias="_l_0", copy=copy)) - .limit(expression, dialect=dialect, copy=False, **opts) + return _apply_builder( + expression=expression, + instance=self, + arg="limit", + into=Limit, + prefix="LIMIT", + dialect=dialect, + copy=copy, + into_arg="expression", + **opts, + ) + + def offset( + self: Q, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts + ) -> Q: + """ + Set the OFFSET expression. + + Example: + >>> Select().from_("tbl").select("x").offset(10).sql() + 'SELECT x FROM tbl OFFSET 10' + + Args: + expression: the SQL code string to parse. + This can also be an integer. + If a `Offset` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in a `Offset`. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + The modified Select expression. + """ + return _apply_builder( + expression=expression, + instance=self, + arg="offset", + into=Offset, + prefix="OFFSET", + dialect=dialect, + copy=copy, + into_arg="expression", + **opts, + ) + + def order_by( + self: Q, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Q: + """ + Set the ORDER BY expression. + + Example: + >>> Select().from_("tbl").select("x").order_by("x DESC").sql() + 'SELECT x FROM tbl ORDER BY x DESC' + + Args: + *expressions: the SQL code strings to parse. + If a `Group` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in a `Order`. + append: if `True`, add to any existing expressions. + Otherwise, this flattens all the `Order` expression into a single expression. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + The modified Select expression. + """ + return _apply_child_list_builder( + *expressions, + instance=self, + arg="order", + append=append, + copy=copy, + prefix="ORDER BY", + into=Order, + dialect=dialect, + **opts, ) @property @@ -1536,7 +1630,13 @@ class SwapTable(Expression): class Comment(Expression): - arg_types = {"this": True, "kind": True, "expression": True, "exists": False} + arg_types = { + "this": True, + "kind": True, + "expression": True, + "exists": False, + "materialized": False, + } class Comprehension(Expression): @@ -1642,6 +1742,10 @@ class ExcludeColumnConstraint(ColumnConstraintKind): pass +class EphemeralColumnConstraint(ColumnConstraintKind): + arg_types = {"this": False} + + class WithOperator(Expression): arg_types = {"this": True, "op": True} @@ -2221,6 +2325,13 @@ class Lateral(UDTF): } +class MatchRecognizeMeasure(Expression): + arg_types = { + "this": True, + "window_frame": False, + } + + class MatchRecognize(Expression): arg_types = { "partition_by": False, @@ -3051,46 +3162,6 @@ class Select(Query): **opts, ) - def order_by( - self, - *expressions: t.Optional[ExpOrStr], - append: bool = True, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Select: - """ - Set the ORDER BY expression. - - Example: - >>> Select().from_("tbl").select("x").order_by("x DESC").sql() - 'SELECT x FROM tbl ORDER BY x DESC' - - Args: - *expressions: the SQL code strings to parse. - If a `Group` instance is passed, this is used as-is. - If another `Expression` instance is passed, it will be wrapped in a `Order`. - append: if `True`, add to any existing expressions. - Otherwise, this flattens all the `Order` expression into a single expression. - dialect: the dialect used to parse the input expression. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - The modified Select expression. - """ - return _apply_child_list_builder( - *expressions, - instance=self, - arg="order", - append=append, - copy=copy, - prefix="ORDER BY", - into=Order, - dialect=dialect, - **opts, - ) - def sort_by( self, *expressions: t.Optional[ExpOrStr], @@ -3171,55 +3242,6 @@ class Select(Query): **opts, ) - def limit( - self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts - ) -> Select: - return _apply_builder( - expression=expression, - instance=self, - arg="limit", - into=Limit, - prefix="LIMIT", - dialect=dialect, - copy=copy, - into_arg="expression", - **opts, - ) - - def offset( - self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts - ) -> Select: - """ - Set the OFFSET expression. - - Example: - >>> Select().from_("tbl").select("x").offset(10).sql() - 'SELECT x FROM tbl OFFSET 10' - - Args: - expression: the SQL code string to parse. - This can also be an integer. - If a `Offset` instance is passed, this is used as-is. - If another `Expression` instance is passed, it will be wrapped in a `Offset`. - dialect: the dialect used to parse the input expression. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - The modified Select expression. - """ - return _apply_builder( - expression=expression, - instance=self, - arg="offset", - into=Offset, - prefix="OFFSET", - dialect=dialect, - copy=copy, - into_arg="expression", - **opts, - ) - def select( self, *expressions: t.Optional[ExpOrStr], @@ -4214,7 +4236,7 @@ class Dot(Binary): parts.reverse() - for arg in ("this", "table", "db", "catalog"): + for arg in COLUMN_PARTS: part = this.args.get(arg) if isinstance(part, Expression): @@ -4395,7 +4417,13 @@ class Between(Predicate): class Bracket(Condition): # https://cloud.google.com/bigquery/docs/reference/standard-sql/operators#array_subscript_operator - arg_types = {"this": True, "expressions": True, "offset": False, "safe": False} + arg_types = { + "this": True, + "expressions": True, + "offset": False, + "safe": False, + "returns_list_for_maps": False, + } @property def output_name(self) -> str: @@ -5458,6 +5486,10 @@ class ApproxQuantile(Quantile): arg_types = {"this": True, "quantile": True, "accuracy": False, "weight": False} +class Quarter(Func): + pass + + class Rand(Func): _sql_names = ["RAND", "RANDOM"] arg_types = {"this": False} @@ -6620,17 +6652,9 @@ def to_interval(interval: str | Literal) -> Interval: ) -@t.overload -def to_table(sql_path: str | Table, **kwargs) -> Table: ... - - -@t.overload -def to_table(sql_path: None, **kwargs) -> None: ... - - def to_table( - sql_path: t.Optional[str | Table], dialect: DialectType = None, copy: bool = True, **kwargs -) -> t.Optional[Table]: + sql_path: str | Table, dialect: DialectType = None, copy: bool = True, **kwargs +) -> Table: """ Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional. If a table is passed in then that table is returned. @@ -6644,35 +6668,54 @@ def to_table( Returns: A table expression. """ - if sql_path is None or isinstance(sql_path, Table): + if isinstance(sql_path, Table): return maybe_copy(sql_path, copy=copy) - if not isinstance(sql_path, str): - raise ValueError(f"Invalid type provided for a table: {type(sql_path)}") table = maybe_parse(sql_path, into=Table, dialect=dialect) - if table: - for k, v in kwargs.items(): - table.set(k, v) + + for k, v in kwargs.items(): + table.set(k, v) return table -def to_column(sql_path: str | Column, **kwargs) -> Column: +def to_column( + sql_path: str | Column, + quoted: t.Optional[bool] = None, + dialect: DialectType = None, + copy: bool = True, + **kwargs, +) -> Column: """ - Create a column from a `[table].[column]` sql path. Schema is optional. - + Create a column from a `[table].[column]` sql path. Table is optional. If a column is passed in then that column is returned. Args: - sql_path: `[table].[column]` string + sql_path: a `[table].[column]` string. + quoted: Whether or not to force quote identifiers. + dialect: the source dialect according to which the column name will be parsed. + copy: Whether to copy a column if it is passed in. + kwargs: the kwargs to instantiate the resulting `Column` expression with. + Returns: - Table: A column expression + A column expression. """ - if sql_path is None or isinstance(sql_path, Column): - return sql_path - if not isinstance(sql_path, str): - raise ValueError(f"Invalid type provided for column: {type(sql_path)}") - return column(*reversed(sql_path.split(".")), **kwargs) # type: ignore + if isinstance(sql_path, Column): + return maybe_copy(sql_path, copy=copy) + + try: + col = maybe_parse(sql_path, into=Column, dialect=dialect) + except ParseError: + return column(*reversed(sql_path.split(".")), quoted=quoted, **kwargs) + + for k, v in kwargs.items(): + col.set(k, v) + + if quoted: + for i in col.find_all(Identifier): + i.set("quoted", True) + + return col def alias_( @@ -6756,7 +6799,7 @@ def subquery( A new Select instance with the subquery expression included. """ - expression = maybe_parse(expression, dialect=dialect, **opts).subquery(alias) + expression = maybe_parse(expression, dialect=dialect, **opts).subquery(alias, **opts) return Select().from_(expression, dialect=dialect, **opts) @@ -6821,7 +6864,9 @@ def column( ) if fields: - this = Dot.build((this, *(to_identifier(field, copy=copy) for field in fields))) + this = Dot.build( + (this, *(to_identifier(field, quoted=quoted, copy=copy) for field in fields)) + ) return this @@ -6840,11 +6885,16 @@ def cast(expression: ExpOrStr, to: DATA_TYPE, copy: bool = True, **opts) -> Cast Returns: The new Cast instance. """ - expression = maybe_parse(expression, copy=copy, **opts) + expr = maybe_parse(expression, copy=copy, **opts) data_type = DataType.build(to, copy=copy, **opts) - expression = Cast(this=expression, to=data_type) - expression.type = data_type - return expression + + if expr.is_type(data_type): + return expr + + expr = Cast(this=expr, to=data_type) + expr.type = data_type + + return expr def table_( @@ -6931,18 +6981,23 @@ def var(name: t.Optional[ExpOrStr]) -> Var: return Var(this=name) -def rename_table(old_name: str | Table, new_name: str | Table) -> AlterTable: +def rename_table( + old_name: str | Table, + new_name: str | Table, + dialect: DialectType = None, +) -> AlterTable: """Build ALTER TABLE... RENAME... expression Args: old_name: The old name of the table new_name: The new name of the table + dialect: The dialect to parse the table. Returns: Alter table expression """ - old_table = to_table(old_name) - new_table = to_table(new_name) + old_table = to_table(old_name, dialect=dialect) + new_table = to_table(new_name, dialect=dialect) return AlterTable( this=old_table, actions=[ @@ -6956,6 +7011,7 @@ def rename_column( old_column_name: str | Column, new_column_name: str | Column, exists: t.Optional[bool] = None, + dialect: DialectType = None, ) -> AlterTable: """Build ALTER TABLE... RENAME COLUMN... expression @@ -6964,13 +7020,14 @@ def rename_column( old_column: The old name of the column new_column: The new name of the column exists: Whether to add the `IF EXISTS` clause + dialect: The dialect to parse the table/column. Returns: Alter table expression """ - table = to_table(table_name) - old_column = to_column(old_column_name) - new_column = to_column(new_column_name) + table = to_table(table_name, dialect=dialect) + old_column = to_column(old_column_name, dialect=dialect) + new_column = to_column(new_column_name, dialect=dialect) return AlterTable( this=table, actions=[ @@ -7366,27 +7423,6 @@ def case( return Case(this=this, ifs=[]) -def cast_unless( - expression: ExpOrStr, - to: DATA_TYPE, - *types: DATA_TYPE, - **opts: t.Any, -) -> Expression | Cast: - """ - Cast an expression to a data type unless it is a specified type. - - Args: - expression: The expression to cast. - to: The data type to cast to. - **types: The types to exclude from casting. - **opts: Extra keyword arguments for parsing `expression` - """ - expr = maybe_parse(expression, **opts) - if expr.is_type(*types): - return expr - return cast(expr, to, **opts) - - def array( *expressions: ExpOrStr, copy: bool = True, dialect: DialectType = None, **kwargs ) -> Array: -- cgit v1.2.3