From ebec59cc5cb6c6856705bf82ced7fe8d9f75b0d0 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Tue, 7 Mar 2023 19:09:31 +0100 Subject: Merging upstream version 11.3.0. Signed-off-by: Daniel Baumann --- sqlglot/__init__.py | 2 +- sqlglot/dataframe/sql/column.py | 8 +- sqlglot/dataframe/sql/dataframe.py | 2 +- sqlglot/dataframe/sql/functions.py | 8 +- sqlglot/dialects/bigquery.py | 4 + sqlglot/dialects/clickhouse.py | 3 + sqlglot/dialects/dialect.py | 5 +- sqlglot/dialects/duckdb.py | 6 + sqlglot/dialects/hive.py | 4 + sqlglot/dialects/mysql.py | 3 + sqlglot/dialects/oracle.py | 17 ++- sqlglot/dialects/postgres.py | 7 +- sqlglot/dialects/snowflake.py | 41 ++++++- sqlglot/dialects/teradata.py | 19 +++- sqlglot/expressions.py | 38 ++++++- sqlglot/generator.py | 81 +++++++++----- sqlglot/optimizer/merge_subqueries.py | 24 +++- sqlglot/optimizer/pushdown_projections.py | 6 +- sqlglot/parser.py | 179 +++++++++++++++++++----------- sqlglot/tokens.py | 8 ++ 20 files changed, 339 insertions(+), 126 deletions(-) (limited to 'sqlglot') diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 87b36b0..d026627 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -47,7 +47,7 @@ if t.TYPE_CHECKING: T = t.TypeVar("T", bound=Expression) -__version__ = "11.2.3" +__version__ = "11.3.0" pretty = False """Whether to format generated SQL by default.""" diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py index 609b2a4..f45d467 100644 --- a/sqlglot/dataframe/sql/column.py +++ b/sqlglot/dataframe/sql/column.py @@ -67,10 +67,10 @@ class Column: return self.binary_op(exp.Mul, other) def __truediv__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.Div, other) + return self.binary_op(exp.FloatDiv, other) def __div__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.Div, other) + return self.binary_op(exp.FloatDiv, other) def __neg__(self) -> Column: return self.unary_op(exp.Neg) @@ -85,10 +85,10 @@ class Column: return self.inverse_binary_op(exp.Mul, other) def __rdiv__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.Div, other) + return self.inverse_binary_op(exp.FloatDiv, other) def __rtruediv__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.Div, other) + return self.inverse_binary_op(exp.FloatDiv, other) def __rmod__(self, other: ColumnOrLiteral) -> Column: return self.inverse_binary_op(exp.Mod, other) diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index 93ca45a..32ee927 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -260,7 +260,7 @@ class DataFrame: @classmethod def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]: expression = item.expression if isinstance(item, DataFrame) else item - return [Column(x) for x in expression.find(exp.Select).expressions] + return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions] @classmethod def _create_hash_from_expression(cls, expression: exp.Select): diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index 8f24746..3c98f42 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -954,10 +954,12 @@ def array_join( col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None ) -> Column: if null_replacement is not None: - return Column.invoke_anonymous_function( - col, "ARRAY_JOIN", lit(delimiter), 
lit(null_replacement) + return Column.invoke_expression_over_column( + col, expression.ArrayJoin, expression=lit(delimiter), null=lit(null_replacement) ) - return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter)) + return Column.invoke_expression_over_column( + col, expression.ArrayJoin, expression=lit(delimiter) + ) def concat(*cols: ColumnOrName) -> Column: diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 32b5147..a3869c6 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -213,7 +213,11 @@ class BigQuery(Dialect): ), } + INTEGER_DIVISION = False + class Generator(generator.Generator): + INTEGER_DIVISION = False + TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore **transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES, # type: ignore diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index b553df2..a78d4db 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -56,6 +56,8 @@ class ClickHouse(Dialect): TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore + INTEGER_DIVISION = False + def _parse_in( self, this: t.Optional[exp.Expression], is_global: bool = False ) -> exp.Expression: @@ -94,6 +96,7 @@ class ClickHouse(Dialect): class Generator(generator.Generator): STRUCT_DELIMITER = ("(", ")") + INTEGER_DIVISION = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index af36256..6939705 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -360,10 +360,9 @@ def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str: if has_schema and is_partitionable: expression = expression.copy() prop = expression.find(exp.PartitionedByProperty) - this = prop and prop.this - if prop and not isinstance(this, exp.Schema): + if prop and prop.this and not isinstance(prop.this, exp.Schema): schema = expression.this - columns = {v.name.upper() for v in this.expressions} + columns = {v.name.upper() for v in prop.this.expressions} partitions = [col for col in schema.expressions if col.name.upper() in columns] schema.set("expressions", [e for e in schema.expressions if e not in partitions]) prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))) diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 6144101..c2755cd 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -83,6 +83,7 @@ class DuckDB(Dialect): ":=": TokenType.EQ, "ATTACH": TokenType.COMMAND, "CHARACTER VARYING": TokenType.VARCHAR, + "EXCLUDE": TokenType.EXCEPT, } class Parser(parser.Parser): @@ -173,3 +174,8 @@ class DuckDB(Dialect): exp.DataType.Type.VARCHAR: "TEXT", exp.DataType.Type.NVARCHAR: "TEXT", } + + STAR_MAPPING = { + **generator.Generator.STAR_MAPPING, + "except": "EXCLUDE", + } diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index ea1191e..44cd875 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -256,7 +256,11 @@ class Hive(Dialect): ), } + INTEGER_DIVISION = False + class Generator(generator.Generator): + INTEGER_DIVISION = False + TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.TEXT: "STRING", diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 836bf3c..b1e20bd 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -300,6 +300,8 @@ class 
MySQL(Dialect): "READ ONLY", } + INTEGER_DIVISION = False + def _parse_show_mysql(self, this, target=False, full=None, global_=None): if target: if isinstance(target, str): @@ -432,6 +434,7 @@ class MySQL(Dialect): class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True NULL_ORDERING_SUPPORTED = False + INTEGER_DIVISION = False TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 74baa8a..795bbeb 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -82,8 +82,17 @@ class Oracle(Dialect): "XMLTABLE": _parse_xml_table, } + INTEGER_DIVISION = False + + def _parse_column(self) -> t.Optional[exp.Expression]: + column = super()._parse_column() + if column: + column.set("join_mark", self._match(TokenType.JOIN_MARKER)) + return column + class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True + INTEGER_DIVISION = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore @@ -108,6 +117,8 @@ class Oracle(Dialect): exp.Trim: trim_sql, exp.Matches: rename_func("DECODE"), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "), + exp.Table: lambda self, e: self.table_sql(e, sep=" "), exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)", exp.Substring: rename_func("SUBSTR"), @@ -139,8 +150,9 @@ class Oracle(Dialect): def offset_sql(self, expression: exp.Offset) -> str: return f"{super().offset_sql(expression)} ROWS" - def table_sql(self, expression: exp.Table, sep: str = " ") -> str: - return super().table_sql(expression, sep=sep) + def column_sql(self, expression: exp.Column) -> str: + column = super().column_sql(expression) + return f"{column} (+)" if expression.args.get("join_mark") else column def xmltable_sql(self, expression: exp.XMLTable) -> str: this = self.sql(expression, "this") @@ -156,6 +168,7 @@ class Oracle(Dialect): class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "(+)": TokenType.JOIN_MARKER, "COLUMNS": TokenType.COLUMN, "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE, "MINUS": TokenType.EXCEPT, diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 3507cb5..35076db 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -222,10 +222,8 @@ class Postgres(Dialect): "BEGIN TRANSACTION": TokenType.BEGIN, "BIGSERIAL": TokenType.BIGSERIAL, "CHARACTER VARYING": TokenType.VARCHAR, - "COMMENT ON": TokenType.COMMAND, "DECLARE": TokenType.COMMAND, "DO": TokenType.COMMAND, - "GRANT": TokenType.COMMAND, "HSTORE": TokenType.HSTORE, "JSONB": TokenType.JSONB, "REFRESH": TokenType.COMMAND, @@ -260,10 +258,7 @@ class Postgres(Dialect): TokenType.HASH: exp.BitwiseXor, } - FACTOR = { - **parser.Parser.FACTOR, # type: ignore - TokenType.CARET: exp.Pow, - } + FACTOR = {**parser.Parser.FACTOR, TokenType.CARET: exp.Pow} class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 5931364..4a090c2 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, @@ -104,6 +106,20 @@ def 
_parse_date_part(self): return self.expression(exp.Extract, this=this, expression=expression) +# https://docs.snowflake.com/en/sql-reference/functions/div0 +def _div0_to_if(args): + cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0)) + true = exp.Literal.number(0) + false = exp.FloatDiv(this=seq_get(args, 0), expression=seq_get(args, 1)) + return exp.If(this=cond, true=true, false=false) + + +# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull +def _zeroifnull_to_if(args): + cond = exp.EQ(this=seq_get(args, 0), expression=exp.Null()) + return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0)) + + def _datatype_sql(self, expression): if expression.this == exp.DataType.Type.ARRAY: return "ARRAY" @@ -150,16 +166,20 @@ class Snowflake(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, "ARRAYAGG": exp.ArrayAgg.from_arg_list, + "ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list, "DATE_TRUNC": lambda args: exp.DateTrunc( unit=exp.Literal.string(seq_get(args, 0).name), # type: ignore this=seq_get(args, 1), ), + "DIV0": _div0_to_if, "IFF": exp.If.from_arg_list, + "TO_ARRAY": exp.Array.from_arg_list, "TO_TIMESTAMP": _snowflake_to_timestamp, "ARRAY_CONSTRUCT": exp.Array.from_arg_list, "RLIKE": exp.RegexpLike.from_arg_list, "DECODE": exp.Matches.from_arg_list, "OBJECT_CONSTRUCT": parser.parse_var_map, + "ZEROIFNULL": _zeroifnull_to_if, } FUNCTION_PARSERS = { @@ -193,6 +213,19 @@ class Snowflake(Dialect): ), } + ALTER_PARSERS = { + **parser.Parser.ALTER_PARSERS, # type: ignore + "UNSET": lambda self: self._parse_alter_table_set_tag(unset=True), + "SET": lambda self: self._parse_alter_table_set_tag(), + } + + INTEGER_DIVISION = False + + def _parse_alter_table_set_tag(self, unset: bool = False) -> exp.Expression: + self._match_text_seq("TAG") + parser = t.cast(t.Callable, self._parse_id_var if unset else self._parse_conjunction) + return self.expression(exp.SetTag, expressions=self._parse_csv(parser), unset=unset) + class Tokenizer(tokens.Tokenizer): QUOTES = ["'", "$$"] STRING_ESCAPES = ["\\", "'"] @@ -220,12 +253,14 @@ class Snowflake(Dialect): class Generator(generator.Generator): PARAMETER_TOKEN = "$" + INTEGER_DIVISION = False TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore exp.Array: inline_array_sql, exp.ArrayConcat: rename_func("ARRAY_CAT"), - exp.DateAdd: rename_func("DATEADD"), + exp.ArrayJoin: rename_func("ARRAY_TO_STRING"), + exp.DateAdd: lambda self, e: self.func("DATEADD", e.text("unit"), e.expression, e.this), exp.DateStrToDate: datestrtodate_sql, exp.DataType: _datatype_sql, exp.If: rename_func("IFF"), @@ -294,6 +329,10 @@ class Snowflake(Dialect): return self.no_identify(lambda: super(self.__class__, self).values_sql(expression)) return super().values_sql(expression) + def settag_sql(self, expression: exp.SetTag) -> str: + action = "UNSET" if expression.args.get("unset") else "SET" + return f"{action} TAG {self.expressions(expression)}" + def select_sql(self, expression: exp.Select) -> str: """Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted and also that all columns in a SELECT are unquoted. 
We also want to make sure that after we find matches where we need diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 7953bc5..415681c 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -74,6 +74,7 @@ class Teradata(Dialect): FUNCTION_PARSERS = { **parser.Parser.FUNCTION_PARSERS, # type: ignore + "RANGE_N": lambda self: self._parse_rangen(), "TRANSLATE": lambda self: self._parse_translate(self.STRICT_CAST), } @@ -105,6 +106,15 @@ class Teradata(Dialect): }, ) + def _parse_rangen(self): + this = self._parse_id_var() + self._match(TokenType.BETWEEN) + + expressions = self._parse_csv(self._parse_conjunction) + each = self._match_text_seq("EACH") and self._parse_conjunction() + + return self.expression(exp.RangeN, this=this, expressions=expressions, each=each) + class Generator(generator.Generator): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore @@ -114,7 +124,6 @@ class Teradata(Dialect): PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, # type: ignore exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX, - exp.VolatilityProperty: exp.Properties.Location.POST_CREATE, } def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> str: @@ -137,3 +146,11 @@ class Teradata(Dialect): type_sql = super().datatype_sql(expression) prefix_sql = expression.args.get("prefix") return f"SYSUDTLIB.{type_sql}" if prefix_sql else type_sql + + def rangen_sql(self, expression: exp.RangeN) -> str: + this = self.sql(expression, "this") + expressions_sql = self.expressions(expression) + each_sql = self.sql(expression, "each") + each_sql = f" EACH {each_sql}" if each_sql else "" + + return f"RANGE_N({this} BETWEEN {expressions_sql}{each_sql})" diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 59881d6..00a3b45 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -35,6 +35,8 @@ from sqlglot.tokens import Token if t.TYPE_CHECKING: from sqlglot.dialects.dialect import DialectType +E = t.TypeVar("E", bound="Expression") + class _Expression(type): def __new__(cls, clsname, bases, attrs): @@ -293,7 +295,7 @@ class Expression(metaclass=_Expression): return self.parent.depth + 1 return 0 - def find(self, *expression_types, bfs=True): + def find(self, *expression_types: t.Type[E], bfs=True) -> E | None: """ Returns the first node in this tree which matches at least one of the specified types. @@ -306,7 +308,7 @@ class Expression(metaclass=_Expression): """ return next(self.find_all(*expression_types, bfs=bfs), None) - def find_all(self, *expression_types, bfs=True): + def find_all(self, *expression_types: t.Type[E], bfs=True) -> t.Iterator[E]: """ Returns a generator object which visits all nodes in this tree and only yields those that match at least one of the specified expression types. @@ -321,7 +323,7 @@ class Expression(metaclass=_Expression): if isinstance(expression, expression_types): yield expression - def find_ancestor(self, *expression_types): + def find_ancestor(self, *expression_types: t.Type[E]) -> E | None: """ Returns a nearest parent matching expression_types. 
@@ -334,7 +336,8 @@ class Expression(metaclass=_Expression): ancestor = self.parent while ancestor and not isinstance(ancestor, expression_types): ancestor = ancestor.parent - return ancestor + # ignore type because mypy doesn't know that we're checking type in the loop + return ancestor # type: ignore[return-value] @property def parent_select(self): @@ -794,6 +797,7 @@ class Create(Expression): "properties": False, "replace": False, "unique": False, + "volatile": False, "indexes": False, "no_schema_binding": False, "begin": False, @@ -883,7 +887,7 @@ class ByteString(Condition): class Column(Condition): - arg_types = {"this": True, "table": False, "db": False, "catalog": False} + arg_types = {"this": True, "table": False, "db": False, "catalog": False, "join_mark": False} @property def table(self) -> str: @@ -926,6 +930,14 @@ class RenameTable(Expression): pass +class SetTag(Expression): + arg_types = {"expressions": True, "unset": False} + + +class Comment(Expression): + arg_types = {"this": True, "kind": True, "expression": True, "exists": False} + + class ColumnConstraint(Expression): arg_types = {"this": False, "kind": True} @@ -2829,6 +2841,14 @@ class Div(Binary): pass +class FloatDiv(Binary): + pass + + +class Overlaps(Binary): + pass + + class Dot(Binary): @property def name(self) -> str: @@ -3125,6 +3145,10 @@ class ArrayFilter(Func): _sql_names = ["FILTER", "ARRAY_FILTER"] +class ArrayJoin(Func): + arg_types = {"this": True, "expression": True, "null": False} + + class ArraySize(Func): arg_types = {"this": True, "expression": False} @@ -3510,6 +3534,10 @@ class ApproxQuantile(Quantile): arg_types = {"this": True, "quantile": True, "accuracy": False, "weight": False} +class RangeN(Func): + arg_types = {"this": True, "expressions": True, "each": False} + + class ReadCSV(Func): _sql_names = ["READ_CSV"] is_var_len_args = True diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 0a7a81f..79501ef 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -109,6 +109,9 @@ class Generator: # Whether or not create function uses an AS before the RETURN CREATE_FUNCTION_RETURN_AS = True + # Whether or not to treat the division operator "/" as integer division + INTEGER_DIVISION = True + TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", exp.DataType.Type.NVARCHAR: "VARCHAR", @@ -550,14 +553,17 @@ class Generator: else: expression_sql = f" AS{expression_sql}" - replace = " OR REPLACE" if expression.args.get("replace") else "" - unique = " UNIQUE" if expression.args.get("unique") else "" - exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else "" + postindex_props_sql = "" + if properties_locs.get(exp.Properties.Location.POST_INDEX): + postindex_props_sql = self.properties( + exp.Properties(expressions=properties_locs[exp.Properties.Location.POST_INDEX]), + wrapped=False, + prefix=" ", + ) indexes = expression.args.get("indexes") - index_sql = "" if indexes: - indexes_sql = [] + indexes_sql: t.List[str] = [] for index in indexes: ind_unique = " UNIQUE" if index.args.get("unique") else "" ind_primary = " PRIMARY" if index.args.get("primary") else "" @@ -568,21 +574,24 @@ class Generator: if index.args.get("columns") else "" ) - if index.args.get("primary") and properties_locs.get( - exp.Properties.Location.POST_INDEX - ): - postindex_props_sql = self.properties( - exp.Properties( - expressions=properties_locs[exp.Properties.Location.POST_INDEX] - ), - wrapped=False, + ind_sql = f"{ind_unique}{ind_primary}{ind_amp} INDEX{ind_name}{ind_columns}" + + if indexes_sql: + 
indexes_sql.append(ind_sql) + else: + indexes_sql.append( + f"{ind_sql}{postindex_props_sql}" + if index.args.get("primary") + else f"{postindex_props_sql}{ind_sql}" ) - ind_columns = f"{ind_columns} {postindex_props_sql}" - indexes_sql.append( - f"{ind_unique}{ind_primary}{ind_amp} INDEX{ind_name}{ind_columns}" - ) index_sql = "".join(indexes_sql) + else: + index_sql = postindex_props_sql + + replace = " OR REPLACE" if expression.args.get("replace") else "" + unique = " UNIQUE" if expression.args.get("unique") else "" + volatile = " VOLATILE" if expression.args.get("volatile") else "" postcreate_props_sql = "" if properties_locs.get(exp.Properties.Location.POST_CREATE): @@ -593,7 +602,7 @@ class Generator: wrapped=False, ) - modifiers = "".join((replace, unique, postcreate_props_sql)) + modifiers = "".join((replace, unique, volatile, postcreate_props_sql)) postexpression_props_sql = "" if properties_locs.get(exp.Properties.Location.POST_EXPRESSION): @@ -606,6 +615,7 @@ class Generator: wrapped=False, ) + exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else "" no_schema_binding = ( " WITH NO SCHEMA BINDING" if expression.args.get("no_schema_binding") else "" ) @@ -1335,14 +1345,15 @@ class Generator: def placeholder_sql(self, expression: exp.Placeholder) -> str: return f":{expression.name}" if expression.name else "?" - def subquery_sql(self, expression: exp.Subquery) -> str: + def subquery_sql(self, expression: exp.Subquery, sep: str = " AS ") -> str: alias = self.sql(expression, "alias") + alias = f"{sep}{alias}" if alias else "" sql = self.query_modifiers( expression, self.wrap(expression), self.expressions(expression, key="pivots", sep=" "), - f" AS {alias}" if alias else "", + alias, ) return self.prepend_ctes(expression, sql) @@ -1643,6 +1654,13 @@ class Generator: def command_sql(self, expression: exp.Command) -> str: return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}" + def comment_sql(self, expression: exp.Comment) -> str: + this = self.sql(expression, "this") + kind = expression.args["kind"] + exists_sql = " IF EXISTS " if expression.args.get("exists") else " " + expression_sql = self.sql(expression, "expression") + return f"COMMENT{exists_sql}ON {kind} {this} IS {expression_sql}" + def transaction_sql(self, *_) -> str: return "BEGIN" @@ -1728,19 +1746,30 @@ class Generator: return f"{self.sql(expression, 'this')} RESPECT NULLS" def intdiv_sql(self, expression: exp.IntDiv) -> str: - return self.sql( - exp.Cast( - this=exp.Div(this=expression.this, expression=expression.expression), - to=exp.DataType(this=exp.DataType.Type.INT), - ) - ) + div = self.binary(expression, "/") + return self.sql(exp.Cast(this=div, to=exp.DataType.build("INT"))) def dpipe_sql(self, expression: exp.DPipe) -> str: return self.binary(expression, "||") def div_sql(self, expression: exp.Div) -> str: + div = self.binary(expression, "/") + + if not self.INTEGER_DIVISION: + return self.sql(exp.Cast(this=div, to=exp.DataType.build("INT"))) + + return div + + def floatdiv_sql(self, expression: exp.FloatDiv) -> str: + if self.INTEGER_DIVISION: + this = exp.Cast(this=expression.this, to=exp.DataType.build("DOUBLE")) + return self.div_sql(exp.Div(this=this, expression=expression.expression)) + return self.binary(expression, "/") + def overlaps_sql(self, expression: exp.Overlaps) -> str: + return self.binary(expression, "OVERLAPS") + def distance_sql(self, expression: exp.Distance) -> str: return self.binary(expression, "<->") diff --git 
a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index 16aaf17..70172f4 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -314,13 +314,27 @@ def _merge_where(outer_scope, inner_scope, from_or_join): if not where or not where.this: return + expression = outer_scope.expression + if isinstance(from_or_join, exp.Join): # Merge predicates from an outer join to the ON clause - from_or_join.on(where.this, copy=False) - from_or_join.set("on", simplify(from_or_join.args.get("on"))) - else: - outer_scope.expression.where(where.this, copy=False) - outer_scope.expression.set("where", simplify(outer_scope.expression.args.get("where"))) + # if it only has columns that are already joined + from_ = expression.args.get("from") + sources = {table.alias_or_name for table in from_.expressions} if from_ else {} + + for join in expression.args["joins"]: + source = join.alias_or_name + sources.add(source) + if source == from_or_join.alias_or_name: + break + + if set(exp.column_table_names(where.this)) <= sources: + from_or_join.on(where.this, copy=False) + from_or_join.set("on", simplify(from_or_join.args.get("on"))) + return + + expression.where(where.this, copy=False) + expression.set("where", simplify(expression.args.get("where"))) def _merge_order(outer_scope, inner_scope): diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index 3f360f9..07a1b70 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -13,7 +13,7 @@ SELECT_ALL = object() DEFAULT_SELECTION = lambda: alias("1", "_") -def pushdown_projections(expression, schema=None): +def pushdown_projections(expression, schema=None, remove_unused_selections=True): """ Rewrite sqlglot AST to remove unused columns projections. 
@@ -26,6 +26,7 @@ def pushdown_projections(expression, schema=None): Args: expression (sqlglot.Expression): expression to optimize + remove_unused_selections (bool): remove selects that are unused Returns: sqlglot.Expression: optimized expression """ @@ -57,7 +58,8 @@ def pushdown_projections(expression, schema=None): ] if isinstance(scope.expression, exp.Select): - _remove_unused_selections(scope, parent_selections, schema) + if remove_unused_selections: + _remove_unused_selections(scope, parent_selections, schema) # Group columns by source name selects = defaultdict(set) diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 9f32765..f39bb39 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -36,6 +36,10 @@ class _Parser(type): klass = super().__new__(cls, clsname, bases, attrs) klass._show_trie = new_trie(key.split(" ") for key in klass.SHOW_PARSERS) klass._set_trie = new_trie(key.split(" ") for key in klass.SET_PARSERS) + + if not klass.INTEGER_DIVISION: + klass.FACTOR = {**klass.FACTOR, TokenType.SLASH: exp.FloatDiv} + return klass @@ -157,6 +161,21 @@ class Parser(metaclass=_Parser): RESERVED_KEYWORDS = {*Tokenizer.SINGLE_TOKENS.values(), TokenType.SELECT} + DB_CREATABLES = { + TokenType.DATABASE, + TokenType.SCHEMA, + TokenType.TABLE, + TokenType.VIEW, + } + + CREATABLES = { + TokenType.COLUMN, + TokenType.FUNCTION, + TokenType.INDEX, + TokenType.PROCEDURE, + *DB_CREATABLES, + } + ID_VAR_TOKENS = { TokenType.VAR, TokenType.ANTI, @@ -168,8 +187,8 @@ class Parser(metaclass=_Parser): TokenType.CACHE, TokenType.CASCADE, TokenType.COLLATE, - TokenType.COLUMN, TokenType.COMMAND, + TokenType.COMMENT, TokenType.COMMIT, TokenType.COMPOUND, TokenType.CONSTRAINT, @@ -186,9 +205,7 @@ class Parser(metaclass=_Parser): TokenType.FILTER, TokenType.FOLLOWING, TokenType.FORMAT, - TokenType.FUNCTION, TokenType.IF, - TokenType.INDEX, TokenType.ISNULL, TokenType.INTERVAL, TokenType.LAZY, @@ -211,13 +228,11 @@ class Parser(metaclass=_Parser): TokenType.RIGHT, TokenType.ROW, TokenType.ROWS, - TokenType.SCHEMA, TokenType.SEED, TokenType.SEMI, TokenType.SET, TokenType.SHOW, TokenType.SORTKEY, - TokenType.TABLE, TokenType.TEMPORARY, TokenType.TOP, TokenType.TRAILING, @@ -226,10 +241,9 @@ class Parser(metaclass=_Parser): TokenType.UNIQUE, TokenType.UNLOGGED, TokenType.UNPIVOT, - TokenType.PROCEDURE, - TokenType.VIEW, TokenType.VOLATILE, TokenType.WINDOW, + *CREATABLES, *SUBQUERY_PREDICATES, *TYPE_TOKENS, *NO_PAREN_FUNCTIONS, @@ -428,6 +442,7 @@ class Parser(metaclass=_Parser): TokenType.BEGIN: lambda self: self._parse_transaction(), TokenType.CACHE: lambda self: self._parse_cache(), TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(), + TokenType.COMMENT: lambda self: self._parse_comment(), TokenType.CREATE: lambda self: self._parse_create(), TokenType.DELETE: lambda self: self._parse_delete(), TokenType.DESC: lambda self: self._parse_describe(), @@ -490,6 +505,9 @@ class Parser(metaclass=_Parser): TokenType.GLOB: lambda self, this: self._parse_escape( self.expression(exp.Glob, this=this, expression=self._parse_bitwise()) ), + TokenType.OVERLAPS: lambda self, this: self._parse_escape( + self.expression(exp.Overlaps, this=this, expression=self._parse_bitwise()) + ), TokenType.IN: lambda self, this: self._parse_in(this), TokenType.IS: lambda self, this: self._parse_is(this), TokenType.LIKE: lambda self, this: self._parse_escape( @@ -628,6 +646,14 @@ class Parser(metaclass=_Parser): "UPPERCASE": lambda self: self.expression(exp.UppercaseColumnConstraint), } + ALTER_PARSERS = { + "ADD": 
lambda self: self._parse_alter_table_add(), + "ALTER": lambda self: self._parse_alter_table_alter(), + "DELETE": lambda self: self.expression(exp.Delete, where=self._parse_where()), + "DROP": lambda self: self._parse_alter_table_drop(), + "RENAME": lambda self: self._parse_alter_table_rename(), + } + SCHEMA_UNNAMED_CONSTRAINTS = {"CHECK", "FOREIGN KEY", "LIKE", "PRIMARY KEY", "UNIQUE"} NO_PAREN_FUNCTION_PARSERS = { @@ -669,16 +695,6 @@ class Parser(metaclass=_Parser): MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table) - CREATABLES = { - TokenType.COLUMN, - TokenType.FUNCTION, - TokenType.INDEX, - TokenType.PROCEDURE, - TokenType.SCHEMA, - TokenType.TABLE, - TokenType.VIEW, - } - TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"} INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"} @@ -689,6 +705,8 @@ class Parser(metaclass=_Parser): STRICT_CAST = True + INTEGER_DIVISION = True + __slots__ = ( "error_level", "error_message_context", @@ -940,6 +958,32 @@ class Parser(metaclass=_Parser): def _parse_command(self) -> exp.Expression: return self.expression(exp.Command, this=self._prev.text, expression=self._parse_string()) + def _parse_comment(self, allow_exists: bool = True) -> exp.Expression: + start = self._prev + exists = self._parse_exists() if allow_exists else None + + self._match(TokenType.ON) + + kind = self._match_set(self.CREATABLES) and self._prev + + if not kind: + return self._parse_as_command(start) + + if kind.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE): + this = self._parse_user_defined_function(kind=kind.token_type) + elif kind.token_type == TokenType.TABLE: + this = self._parse_table() + elif kind.token_type == TokenType.COLUMN: + this = self._parse_column() + else: + this = self._parse_id_var() + + self._match(TokenType.IS) + + return self.expression( + exp.Comment, this=this, kind=kind.text, expression=self._parse_string(), exists=exists + ) + def _parse_statement(self) -> t.Optional[exp.Expression]: if self._curr is None: return None @@ -990,6 +1034,7 @@ class Parser(metaclass=_Parser): TokenType.OR, TokenType.REPLACE ) unique = self._match(TokenType.UNIQUE) + volatile = self._match(TokenType.VOLATILE) if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False): self._match(TokenType.TABLE) @@ -1028,11 +1073,7 @@ class Parser(metaclass=_Parser): expression = self.expression(exp.Return, this=expression) elif create_token.token_type == TokenType.INDEX: this = self._parse_index() - elif create_token.token_type in ( - TokenType.TABLE, - TokenType.VIEW, - TokenType.SCHEMA, - ): + elif create_token.token_type in self.DB_CREATABLES: table_parts = self._parse_table_parts(schema=True) # exp.Properties.Location.POST_NAME @@ -1100,11 +1141,12 @@ class Parser(metaclass=_Parser): exp.Create, this=this, kind=create_token.text, + replace=replace, unique=unique, + volatile=volatile, expression=expression, exists=exists, properties=properties, - replace=replace, indexes=indexes, no_schema_binding=no_schema_binding, begin=begin, @@ -3648,6 +3690,47 @@ class Parser(metaclass=_Parser): return self.expression(exp.AddConstraint, this=this, expression=expression) + def _parse_alter_table_add(self) -> t.List[t.Optional[exp.Expression]]: + index = self._index - 1 + + if self._match_set(self.ADD_CONSTRAINT_TOKENS): + return self._parse_csv(self._parse_add_constraint) + + self._retreat(index) + return self._parse_csv(self._parse_add_column) + + def _parse_alter_table_alter(self) -> exp.Expression: + self._match(TokenType.COLUMN) + column = 
self._parse_field(any_token=True) + + if self._match_pair(TokenType.DROP, TokenType.DEFAULT): + return self.expression(exp.AlterColumn, this=column, drop=True) + if self._match_pair(TokenType.SET, TokenType.DEFAULT): + return self.expression(exp.AlterColumn, this=column, default=self._parse_conjunction()) + + self._match_text_seq("SET", "DATA") + return self.expression( + exp.AlterColumn, + this=column, + dtype=self._match_text_seq("TYPE") and self._parse_types(), + collate=self._match(TokenType.COLLATE) and self._parse_term(), + using=self._match(TokenType.USING) and self._parse_conjunction(), + ) + + def _parse_alter_table_drop(self) -> t.List[t.Optional[exp.Expression]]: + index = self._index - 1 + + partition_exists = self._parse_exists() + if self._match(TokenType.PARTITION, advance=False): + return self._parse_csv(lambda: self._parse_drop_partition(exists=partition_exists)) + + self._retreat(index) + return self._parse_csv(self._parse_drop_column) + + def _parse_alter_table_rename(self) -> exp.Expression: + self._match_text_seq("TO") + return self.expression(exp.RenameTable, this=self._parse_table(schema=True)) + def _parse_alter(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.TABLE): return self._parse_as_command(self._prev) @@ -3655,50 +3738,12 @@ class Parser(metaclass=_Parser): exists = self._parse_exists() this = self._parse_table(schema=True) - actions: t.Optional[exp.Expression | t.List[t.Optional[exp.Expression]]] = None - - index = self._index - if self._match(TokenType.DELETE): - actions = [self.expression(exp.Delete, where=self._parse_where())] - elif self._match_text_seq("ADD"): - if self._match_set(self.ADD_CONSTRAINT_TOKENS): - actions = self._parse_csv(self._parse_add_constraint) - else: - self._retreat(index) - actions = self._parse_csv(self._parse_add_column) - elif self._match_text_seq("DROP"): - partition_exists = self._parse_exists() + if not self._curr: + return None - if self._match(TokenType.PARTITION, advance=False): - actions = self._parse_csv( - lambda: self._parse_drop_partition(exists=partition_exists) - ) - else: - self._retreat(index) - actions = self._parse_csv(self._parse_drop_column) - elif self._match_text_seq("RENAME", "TO"): - actions = self.expression(exp.RenameTable, this=self._parse_table(schema=True)) - elif self._match_text_seq("ALTER"): - self._match(TokenType.COLUMN) - column = self._parse_field(any_token=True) - - if self._match_pair(TokenType.DROP, TokenType.DEFAULT): - actions = self.expression(exp.AlterColumn, this=column, drop=True) - elif self._match_pair(TokenType.SET, TokenType.DEFAULT): - actions = self.expression( - exp.AlterColumn, this=column, default=self._parse_conjunction() - ) - else: - self._match_text_seq("SET", "DATA") - actions = self.expression( - exp.AlterColumn, - this=column, - dtype=self._match_text_seq("TYPE") and self._parse_types(), - collate=self._match(TokenType.COLLATE) and self._parse_term(), - using=self._match(TokenType.USING) and self._parse_conjunction(), - ) + parser = self.ALTER_PARSERS.get(self._curr.text.upper()) + actions = ensure_list(self._advance() or parser(self)) if parser else [] # type: ignore - actions = ensure_list(actions) return self.expression(exp.AlterTable, this=this, exists=exists, actions=actions) def _parse_show(self) -> t.Optional[exp.Expression]: @@ -3772,7 +3817,9 @@ class Parser(metaclass=_Parser): def _parse_as_command(self, start: Token) -> exp.Command: while self._curr: self._advance() - return exp.Command(this=self._find_sql(start, self._prev)) + text = 
self._find_sql(start, self._prev) + size = len(start.text) + return exp.Command(this=text[:size], expression=text[size:]) def _find_parser( self, parsers: t.Dict[str, t.Callable], trie: t.Dict diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index f3f1a70..7a23803 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -60,6 +60,7 @@ class TokenType(AutoName): STRING = auto() NUMBER = auto() IDENTIFIER = auto() + DATABASE = auto() COLUMN = auto() COLUMN_DEF = auto() SCHEMA = auto() @@ -203,6 +204,7 @@ class TokenType(AutoName): IS = auto() ISNULL = auto() JOIN = auto() + JOIN_MARKER = auto() LANGUAGE = auto() LATERAL = auto() LAZY = auto() @@ -235,6 +237,7 @@ class TokenType(AutoName): OUTER = auto() OUT_OF = auto() OVER = auto() + OVERLAPS = auto() OVERWRITE = auto() PARTITION = auto() PARTITION_BY = auto() @@ -491,6 +494,7 @@ class Tokenizer(metaclass=_Tokenizer): "CURRENT_DATE": TokenType.CURRENT_DATE, "CURRENT ROW": TokenType.CURRENT_ROW, "CURRENT_TIMESTAMP": TokenType.CURRENT_TIMESTAMP, + "DATABASE": TokenType.DATABASE, "DEFAULT": TokenType.DEFAULT, "DELETE": TokenType.DELETE, "DESC": TokenType.DESC, @@ -564,6 +568,7 @@ class Tokenizer(metaclass=_Tokenizer): "OUTER": TokenType.OUTER, "OUT OF": TokenType.OUT_OF, "OVER": TokenType.OVER, + "OVERLAPS": TokenType.OVERLAPS, "OVERWRITE": TokenType.OVERWRITE, "PARTITION": TokenType.PARTITION, "PARTITION BY": TokenType.PARTITION_BY, @@ -652,6 +657,7 @@ class Tokenizer(metaclass=_Tokenizer): "DOUBLE PRECISION": TokenType.DOUBLE, "JSON": TokenType.JSON, "CHAR": TokenType.CHAR, + "CHARACTER": TokenType.CHAR, "NCHAR": TokenType.NCHAR, "VARCHAR": TokenType.VARCHAR, "VARCHAR2": TokenType.VARCHAR, @@ -687,8 +693,10 @@ class Tokenizer(metaclass=_Tokenizer): "ALTER VIEW": TokenType.COMMAND, "ANALYZE": TokenType.COMMAND, "CALL": TokenType.COMMAND, + "COMMENT": TokenType.COMMENT, "COPY": TokenType.COMMAND, "EXPLAIN": TokenType.COMMAND, + "GRANT": TokenType.COMMAND, "OPTIMIZE": TokenType.COMMAND, "PREPARE": TokenType.COMMAND, "TRUNCATE": TokenType.COMMAND, -- cgit v1.2.3
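
The most invasive change in this patch is the division rework: dialects now declare
INTEGER_DIVISION on both parser and generator, "/" parses to exp.FloatDiv in dialects
where it denotes float division (BigQuery, ClickHouse, Hive, MySQL, Oracle, Snowflake),
and each generator inserts casts as needed so the arithmetic keeps its meaning across
dialects. What follows is a minimal sketch of the intended behaviour, not part of the
patch: the transpile calls are real sqlglot API, but the output strings are inferred
from the div_sql/floatdiv_sql and _div0_to_if hunks above rather than copied from the
11.3.0 test suite.

    import sqlglot

    # Hive treats "/" as float division (INTEGER_DIVISION = False), so the
    # expression parses to exp.FloatDiv; Presto's "/" truncates integers, so
    # its generator casts the numerator to DOUBLE to keep float semantics.
    sqlglot.transpile("SELECT a / b FROM t", read="hive", write="presto")
    # inferred output: ['SELECT CAST(a AS DOUBLE) / b FROM t']

    # Going the other way, an integer-division exp.Div is wrapped in a cast
    # so the result stays integral in a float-division dialect.
    sqlglot.transpile("SELECT a / b FROM t", read="presto", write="hive")
    # inferred output: ['SELECT CAST(a / b AS INT) FROM t']

    # Snowflake's DIV0 is parsed into an IF guard (see _div0_to_if above) and
    # rendered back with IFF, since exp.If maps to IFF for Snowflake.
    sqlglot.transpile("SELECT DIV0(x, y)", read="snowflake", write="snowflake")
    # inferred output: ['SELECT IFF(y = 0, 0, x / y)']

Note that keying the casts off the target generator, rather than rewriting the AST at
parse time, should leave same-dialect round-trips untouched: "a / b" read and written
in the same dialect renders back as "a / b" in both the integer- and float-division
cases above.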