| author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-08 08:11:53 +0000 |
|---|---|---|
| committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-08 08:12:02 +0000 |
| commit | 8d36f5966675e23bee7026ba37ae0647fbf47300 (patch) | |
| tree | df4227bbb3b07cb70df87237bcff03c8efd7822d /sqlglot | |
| parent | Releasing debian version 22.2.0-1. (diff) | |
| download | sqlglot-8d36f5966675e23bee7026ba37ae0647fbf47300.tar.xz, sqlglot-8d36f5966675e23bee7026ba37ae0647fbf47300.zip | |
Merging upstream version 23.7.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
54 files changed, 2549 insertions, 1070 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index e30232c..756532f 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -45,7 +45,7 @@ from sqlglot.expressions import (
 from sqlglot.generator import Generator as Generator
 from sqlglot.parser import Parser as Parser
 from sqlglot.schema import MappingSchema as MappingSchema, Schema as Schema
-from sqlglot.tokens import Tokenizer as Tokenizer, TokenType as TokenType
+from sqlglot.tokens import Token as Token, Tokenizer as Tokenizer, TokenType as TokenType
 
 if t.TYPE_CHECKING:
     from sqlglot._typing import E
@@ -69,6 +69,21 @@ schema = MappingSchema()
 """The default schema used by SQLGlot (e.g. in the optimizer)."""
 
 
+def tokenize(sql: str, read: DialectType = None, dialect: DialectType = None) -> t.List[Token]:
+    """
+    Tokenizes the given SQL string.
+
+    Args:
+        sql: the SQL code string to tokenize.
+        read: the SQL dialect to apply during tokenizing (eg. "spark", "hive", "presto", "mysql").
+        dialect: the SQL dialect (alias for read).
+
+    Returns:
+        The resulting list of tokens.
+    """
+    return Dialect.get_or_raise(read or dialect).tokenize(sql)
+
+
 def parse(
     sql: str, read: DialectType = None, dialect: DialectType = None, **opts
 ) -> t.List[t.Optional[Expression]]:
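Editor's note: the new top-level `tokenize` helper mirrors `parse`/`transpile`. A minimal usage sketch (output illustrative):

```python
import sqlglot

# Tokenize without building an AST; `read` picks the dialect's tokenizer.
for token in sqlglot.tokenize("SELECT 1 AS x", read="mysql"):
    print(token.token_type, token.text)
```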
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 0bacbf9..8316c36 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -18,8 +18,6 @@ from sqlglot.dataframe.sql.transforms import replace_id_value
 from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
 from sqlglot.dataframe.sql.window import Window
 from sqlglot.helper import ensure_list, object_to_dict, seq_get
-from sqlglot.optimizer import optimize as optimize_func
-from sqlglot.optimizer.qualify_columns import quote_identifiers
 
 if t.TYPE_CHECKING:
     from sqlglot.dataframe.sql._typing import (
@@ -121,7 +119,9 @@ class DataFrame:
                 self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
             )
             replacement_mapping[old_name_id] = new_hashed_id
-            expression = expression.transform(replace_id_value, replacement_mapping)
+            expression = expression.transform(replace_id_value, replacement_mapping).assert_is(
+                exp.Select
+            )
         return expression
 
     def _create_cte_from_expression(
@@ -306,11 +306,12 @@ class DataFrame:
         replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
 
         for expression_type, select_expression in select_expressions:
-            select_expression = select_expression.transform(replace_id_value, replacement_mapping)
+            select_expression = select_expression.transform(
+                replace_id_value, replacement_mapping
+            ).assert_is(exp.Select)
 
             if optimize:
-                quote_identifiers(select_expression, dialect=dialect)
                 select_expression = t.cast(
-                    exp.Select, optimize_func(select_expression, dialect=dialect)
+                    exp.Select, self.spark._optimize(select_expression, dialect=dialect)
                 )
 
             select_expression = df._replace_cte_names_with_hashes(select_expression)
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index db5201f..b4dd2c6 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -184,7 +184,7 @@ def floor(col: ColumnOrName) -> Column:
 
 
 def log10(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, expression.Log10)
+    return Column.invoke_expression_over_column(lit(10), expression.Log, expression=col)
 
 
 def log1p(col: ColumnOrName) -> Column:
@@ -192,7 +192,7 @@ def log1p(col: ColumnOrName) -> Column:
 
 
 def log2(col: ColumnOrName) -> Column:
-    return Column.invoke_expression_over_column(col, expression.Log2)
+    return Column.invoke_expression_over_column(lit(2), expression.Log, expression=col)
 
 
 def log(arg1: t.Union[ColumnOrName, float], arg2: t.Optional[ColumnOrName] = None) -> Column:
@@ -356,15 +356,15 @@ def coalesce(*cols: ColumnOrName) -> Column:
 
 
 def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col1, "CORR", col2)
+    return Column.invoke_expression_over_column(col1, expression.Corr, expression=col2)
 
 
 def covar_pop(col1: ColumnOrName, col2: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col1, "COVAR_POP", col2)
+    return Column.invoke_expression_over_column(col1, expression.CovarPop, expression=col2)
 
 
 def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(col1, "COVAR_SAMP", col2)
+    return Column.invoke_expression_over_column(col1, expression.CovarSamp, expression=col2)
 
 
 def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
@@ -971,10 +971,10 @@ def array_join(
 ) -> Column:
     if null_replacement is not None:
         return Column.invoke_expression_over_column(
-            col, expression.ArrayJoin, expression=lit(delimiter), null=lit(null_replacement)
+            col, expression.ArrayToString, expression=lit(delimiter), null=lit(null_replacement)
         )
     return Column.invoke_expression_over_column(
-        col, expression.ArrayJoin, expression=lit(delimiter)
+        col, expression.ArrayToString, expression=lit(delimiter)
    )
diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py
index bfc022b..4e47aaa 100644
--- a/sqlglot/dataframe/sql/session.py
+++ b/sqlglot/dataframe/sql/session.py
@@ -12,6 +12,8 @@ from sqlglot.dataframe.sql.readwriter import DataFrameReader
 from sqlglot.dataframe.sql.types import StructType
 from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input
 from sqlglot.helper import classproperty
+from sqlglot.optimizer import optimize
+from sqlglot.optimizer.qualify_columns import quote_identifiers
 
 if t.TYPE_CHECKING:
     from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput
@@ -104,8 +106,15 @@ class SparkSession:
         sel_expression = exp.Select(**select_kwargs)
         return DataFrame(self, sel_expression)
 
+    def _optimize(
+        self, expression: exp.Expression, dialect: t.Optional[Dialect] = None
+    ) -> exp.Expression:
+        dialect = dialect or self.dialect
+        quote_identifiers(expression, dialect=dialect)
+        return optimize(expression, dialect=dialect)
+
     def sql(self, sqlQuery: str) -> DataFrame:
-        expression = sqlglot.parse_one(sqlQuery, read=self.dialect)
+        expression = self._optimize(sqlglot.parse_one(sqlQuery, read=self.dialect))
         if isinstance(expression, exp.Select):
            df = DataFrame(self, expression)
             df = df._convert_leaf_to_cte()
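Editor's note: `SparkSession.sql` now runs the shared `_optimize` pass (quote + optimize) on parse. A hedged sketch, assuming the table schema is registered on the global `sqlglot.schema` so the optimizer can resolve columns:

```python
import sqlglot
from sqlglot.dataframe.sql.session import SparkSession

# Register a schema for the (assumed) table name "t" so optimize() can qualify columns.
sqlglot.schema.add_table("t", {"a": "INT", "b": "INT"})

spark = SparkSession()
df = spark.sql("SELECT a FROM t")
print(df.sql(dialect="spark")[0])  # optimized, fully qualified SQL (output illustrative)
```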
diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py
index 82552c9..29c6580 100644
--- a/sqlglot/dialects/__init__.py
+++ b/sqlglot/dialects/__init__.py
@@ -61,6 +61,7 @@ dialect implementations in order to understand how their various components can
 ----
 """
 
+from sqlglot.dialects.athena import Athena
 from sqlglot.dialects.bigquery import BigQuery
 from sqlglot.dialects.clickhouse import ClickHouse
 from sqlglot.dialects.databricks import Databricks
@@ -73,6 +74,7 @@ from sqlglot.dialects.mysql import MySQL
 from sqlglot.dialects.oracle import Oracle
 from sqlglot.dialects.postgres import Postgres
 from sqlglot.dialects.presto import Presto
+from sqlglot.dialects.prql import PRQL
 from sqlglot.dialects.redshift import Redshift
 from sqlglot.dialects.snowflake import Snowflake
 from sqlglot.dialects.spark import Spark
diff --git a/sqlglot/dialects/athena.py b/sqlglot/dialects/athena.py
new file mode 100644
index 0000000..f2deec8
--- /dev/null
+++ b/sqlglot/dialects/athena.py
@@ -0,0 +1,37 @@
+from __future__ import annotations
+
+from sqlglot import exp
+from sqlglot.dialects.trino import Trino
+from sqlglot.tokens import TokenType
+
+
+class Athena(Trino):
+    class Parser(Trino.Parser):
+        STATEMENT_PARSERS = {
+            **Trino.Parser.STATEMENT_PARSERS,
+            TokenType.USING: lambda self: self._parse_as_command(self._prev),
+        }
+
+    class Generator(Trino.Generator):
+        PROPERTIES_LOCATION = {
+            **Trino.Generator.PROPERTIES_LOCATION,
+            exp.LocationProperty: exp.Properties.Location.POST_SCHEMA,
+        }
+
+        TYPE_MAPPING = {
+            **Trino.Generator.TYPE_MAPPING,
+            exp.DataType.Type.TEXT: "STRING",
+        }
+
+        TRANSFORMS = {
+            **Trino.Generator.TRANSFORMS,
+            exp.FileFormatProperty: lambda self, e: f"'FORMAT'={self.sql(e, 'this')}",
+        }
+
+        def property_sql(self, expression: exp.Property) -> str:
+            return (
+                f"{self.property_name(expression, string_key=True)}={self.sql(expression, 'value')}"
+            )
+
+        def with_properties(self, properties: exp.Properties) -> str:
+            return self.properties(properties, prefix=self.seg("TBLPROPERTIES"))
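Editor's note: since `Athena` subclasses `Trino`, queries transpile unchanged while DDL diverges (e.g. `TEXT` maps to `STRING`). A quick hedged check:

```python
import sqlglot

# TEXT is remapped to STRING by the new Athena generator (output illustrative).
print(sqlglot.transpile("CREATE TABLE t (c TEXT)", write="athena")[0])
# e.g. 'CREATE TABLE t (c STRING)'
```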
("cluster", "distribute", "sort"): + modifier_transforms.pop(modifier, None) + + klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms + if not klass.SUPPORTS_SEMI_ANTI_JOIN: klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | { TokenType.ANTI, @@ -189,8 +220,11 @@ class Dialect(metaclass=_Dialect): False: Disables function name normalization. """ - LOG_BASE_FIRST = True - """Whether the base comes first in the `LOG` function.""" + LOG_BASE_FIRST: t.Optional[bool] = True + """ + Whether the base comes first in the `LOG` function. + Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) + """ NULL_ORDERING = "nulls_are_small" """ @@ -226,8 +260,8 @@ class Dialect(metaclass=_Dialect): If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. """ - ESCAPE_SEQUENCES: t.Dict[str, str] = {} - """Mapping of an unescaped escape sequence to the corresponding character.""" + UNESCAPED_SEQUENCES: t.Dict[str, str] = {} + """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" PSEUDOCOLUMNS: t.Set[str] = set() """ @@ -266,7 +300,7 @@ class Dialect(metaclass=_Dialect): INVERSE_TIME_MAPPING: t.Dict[str, str] = {} INVERSE_TIME_TRIE: t.Dict = {} - INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {} + ESCAPED_SEQUENCES: t.Dict[str, str] = {} # Delimiters for string literals and identifiers QUOTE_START = "'" @@ -587,13 +621,21 @@ def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> return "" -def str_position_sql(self: Generator, expression: exp.StrPosition) -> str: +def str_position_sql( + self: Generator, expression: exp.StrPosition, generate_instance: bool = False +) -> str: this = self.sql(expression, "this") substr = self.sql(expression, "substr") position = self.sql(expression, "position") + instance = expression.args.get("instance") if generate_instance else None + position_offset = "" + if position: - return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1" - return f"STRPOS({this}, {substr})" + # Normalize third 'pos' argument into 'SUBSTR(..) 
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 90167f6..631dc30 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -15,7 +15,6 @@ from sqlglot.dialects.dialect import (
     rename_func,
     var_map_sql,
 )
-from sqlglot.errors import ParseError
 from sqlglot.helper import is_int, seq_get
 from sqlglot.tokens import Token, TokenType
 
@@ -49,8 +48,9 @@ class ClickHouse(Dialect):
     NULL_ORDERING = "nulls_are_last"
     SUPPORTS_USER_DEFINED_TYPES = False
     SAFE_DIVISION = True
+    LOG_BASE_FIRST: t.Optional[bool] = None
 
-    ESCAPE_SEQUENCES = {
+    UNESCAPED_SEQUENCES = {
         "\\0": "\0",
     }
 
@@ -105,6 +105,7 @@ class ClickHouse(Dialect):
         # * select x from t1 union all select x from t2 limit 1;
         # * select x from t1 union all (select x from t2 limit 1);
         MODIFIERS_ATTACHED_TO_UNION = False
+        INTERVAL_SPANS = False
 
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
@@ -260,6 +261,11 @@ class ClickHouse(Dialect):
             "ArgMax",
         ]
 
+        FUNC_TOKENS = {
+            *parser.Parser.FUNC_TOKENS,
+            TokenType.SET,
+        }
+
         AGG_FUNC_MAPPING = (
             lambda functions, suffixes: {
                 f"{f}{sfx}": (f, sfx) for sfx in (suffixes + [""]) for f in functions
@@ -305,6 +311,10 @@ class ClickHouse(Dialect):
             TokenType.SETTINGS,
         }
 
+        ALIAS_TOKENS = parser.Parser.ALIAS_TOKENS - {
+            TokenType.FORMAT,
+        }
+
         LOG_DEFAULTS_TO_LN = True
 
         QUERY_MODIFIER_PARSERS = {
@@ -316,6 +326,17 @@ class ClickHouse(Dialect):
             TokenType.FORMAT: lambda self: ("format", self._advance() or self._parse_id_var()),
         }
 
+        CONSTRAINT_PARSERS = {
+            **parser.Parser.CONSTRAINT_PARSERS,
+            "INDEX": lambda self: self._parse_index_constraint(),
+            "CODEC": lambda self: self._parse_compress(),
+        }
+
+        SCHEMA_UNNAMED_CONSTRAINTS = {
+            *parser.Parser.SCHEMA_UNNAMED_CONSTRAINTS,
+            "INDEX",
+        }
+
         def _parse_conjunction(self) -> t.Optional[exp.Expression]:
             this = super()._parse_conjunction()
 
@@ -381,21 +402,20 @@ class ClickHouse(Dialect):
 
         # https://clickhouse.com/docs/en/sql-reference/statements/select/with/
         def _parse_cte(self) -> exp.CTE:
-            index = self._index
-            try:
-                # WITH <identifier> AS <subquery expression>
-                return super()._parse_cte()
-            except ParseError:
-                # WITH <expression> AS <identifier>
-                self._retreat(index)
+            # WITH <identifier> AS <subquery expression>
+            cte: t.Optional[exp.CTE] = self._try_parse(super()._parse_cte)
 
-            return self.expression(
+            if not cte:
+                # WITH <expression> AS <identifier>
+                cte = self.expression(
                     exp.CTE,
                     this=self._parse_conjunction(),
                     alias=self._parse_table_alias(),
                     scalar=True,
                 )
 
+            return cte
+
         def _parse_join_parts(
             self,
         ) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]:
@@ -508,6 +528,27 @@ class ClickHouse(Dialect):
             self._retreat(index)
             return None
 
+        def _parse_index_constraint(
+            self, kind: t.Optional[str] = None
+        ) -> exp.IndexColumnConstraint:
+            # INDEX name1 expr TYPE type1(args) GRANULARITY value
+            this = self._parse_id_var()
+            expression = self._parse_conjunction()
+
+            index_type = self._match_text_seq("TYPE") and (
+                self._parse_function() or self._parse_var()
+            )
+
+            granularity = self._match_text_seq("GRANULARITY") and self._parse_term()
+
+            return self.expression(
+                exp.IndexColumnConstraint,
+                this=this,
+                expression=expression,
+                index_type=index_type,
+                granularity=granularity,
+            )
+
    class Generator(generator.Generator):
         QUERY_HINTS = False
         STRUCT_DELIMITER = ("(", ")")
@@ -517,6 +558,7 @@ class ClickHouse(Dialect):
         TABLESAMPLE_KEYWORDS = "SAMPLE"
         LAST_DAY_SUPPORTS_DATE_PART = False
         CAN_IMPLEMENT_ARRAY_ANY = True
+        SUPPORTS_TO_NUMBER = False
 
         STRING_TYPE_MAPPING = {
             exp.DataType.Type.CHAR: "String",
@@ -585,6 +627,9 @@ class ClickHouse(Dialect):
             exp.Array: inline_array_sql,
             exp.CastToStrType: rename_func("CAST"),
             exp.CountIf: rename_func("countIf"),
+            exp.CompressColumnConstraint: lambda self,
+            e: f"CODEC({self.expressions(e, key='this', flat=True)})",
+            exp.ComputedColumnConstraint: lambda self, e: f"ALIAS {self.sql(e, 'this')}",
             exp.CurrentDate: lambda self, e: self.func("CURRENT_DATE"),
             exp.DateAdd: date_delta_sql("DATE_ADD"),
             exp.DateDiff: date_delta_sql("DATE_DIFF"),
@@ -737,3 +782,15 @@ class ClickHouse(Dialect):
         def prewhere_sql(self, expression: exp.PreWhere) -> str:
             this = self.indent(self.sql(expression, "this"))
             return f"{self.seg('PREWHERE')}{self.sep()}{this}"
+
+        def indexcolumnconstraint_sql(self, expression: exp.IndexColumnConstraint) -> str:
+            this = self.sql(expression, "this")
+            this = f" {this}" if this else ""
+            expr = self.sql(expression, "expression")
+            expr = f" {expr}" if expr else ""
+            index_type = self.sql(expression, "index_type")
+            index_type = f" TYPE {index_type}" if index_type else ""
+            granularity = self.sql(expression, "granularity")
+            granularity = f" GRANULARITY {granularity}" if granularity else ""
+
+            return f"INDEX{this}{expr}{index_type}{granularity}"
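Editor's note: the new `INDEX` constraint parser/generator pair should make ClickHouse inline index definitions round-trip; a hedged sketch:

```python
import sqlglot

ddl = "CREATE TABLE t (x Int32, INDEX idx1 x TYPE minmax GRANULARITY 8192)"
# _parse_index_constraint builds an exp.IndexColumnConstraint and
# indexcolumnconstraint_sql prints it back (output illustrative).
print(sqlglot.parse_one(ddl, read="clickhouse").sql(dialect="clickhouse"))
```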
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 599505c..81057c2 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -31,6 +31,7 @@ class Dialects(str, Enum):
     DIALECT = ""
 
+    ATHENA = "athena"
     BIGQUERY = "bigquery"
     CLICKHOUSE = "clickhouse"
     DATABRICKS = "databricks"
@@ -42,6 +43,7 @@ class Dialects(str, Enum):
     ORACLE = "oracle"
     POSTGRES = "postgres"
     PRESTO = "presto"
+    PRQL = "prql"
     REDSHIFT = "redshift"
     SNOWFLAKE = "snowflake"
     SPARK = "spark"
@@ -108,11 +110,18 @@ class _Dialect(type):
         klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
         klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
 
-        klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()}
+        base = seq_get(bases, 0)
+        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
+        base_parser = (getattr(base, "parser_class", Parser),)
+        base_generator = (getattr(base, "generator_class", Generator),)
 
-        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
-        klass.parser_class = getattr(klass, "Parser", Parser)
-        klass.generator_class = getattr(klass, "Generator", Generator)
+        klass.tokenizer_class = klass.__dict__.get(
+            "Tokenizer", type("Tokenizer", base_tokenizer, {})
+        )
+        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
+        klass.generator_class = klass.__dict__.get(
+            "Generator", type("Generator", base_generator, {})
+        )
 
         klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
         klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
@@ -134,9 +143,31 @@ class _Dialect(type):
         klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
         klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)
 
+        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
+            klass.UNESCAPED_SEQUENCES = {
+                "\\a": "\a",
+                "\\b": "\b",
+                "\\f": "\f",
+                "\\n": "\n",
+                "\\r": "\r",
+                "\\t": "\t",
+                "\\v": "\v",
+                "\\\\": "\\",
+                **klass.UNESCAPED_SEQUENCES,
+            }
+
+        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}
+
         if enum not in ("", "bigquery"):
             klass.generator_class.SELECT_KINDS = ()
 
+        if enum not in ("", "databricks", "hive", "spark", "spark2"):
+            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
+            for modifier in ("cluster", "distribute", "sort"):
+                modifier_transforms.pop(modifier, None)
+
+            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms
+
         if not klass.SUPPORTS_SEMI_ANTI_JOIN:
             klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                 TokenType.ANTI,
@@ -189,8 +220,11 @@ class Dialect(metaclass=_Dialect):
         False: Disables function name normalization.
     """
 
-    LOG_BASE_FIRST = True
-    """Whether the base comes first in the `LOG` function."""
+    LOG_BASE_FIRST: t.Optional[bool] = True
+    """
+    Whether the base comes first in the `LOG` function.
+    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
+    """
 
     NULL_ORDERING = "nulls_are_small"
     """
@@ -226,8 +260,8 @@ class Dialect(metaclass=_Dialect):
     If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
     """
 
-    ESCAPE_SEQUENCES: t.Dict[str, str] = {}
-    """Mapping of an unescaped escape sequence to the corresponding character."""
+    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
+    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""
 
     PSEUDOCOLUMNS: t.Set[str] = set()
     """
@@ -266,7 +300,7 @@ class Dialect(metaclass=_Dialect):
     INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
     INVERSE_TIME_TRIE: t.Dict = {}
 
-    INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {}
+    ESCAPED_SEQUENCES: t.Dict[str, str] = {}
 
     # Delimiters for string literals and identifiers
     QUOTE_START = "'"
@@ -587,13 +621,21 @@ def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str:
     return ""
 
 
-def str_position_sql(self: Generator, expression: exp.StrPosition) -> str:
+def str_position_sql(
+    self: Generator, expression: exp.StrPosition, generate_instance: bool = False
+) -> str:
     this = self.sql(expression, "this")
     substr = self.sql(expression, "substr")
     position = self.sql(expression, "position")
+    instance = expression.args.get("instance") if generate_instance else None
+    position_offset = ""
+
     if position:
-        return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
-    return f"STRPOS({this}, {substr})"
+        # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects
+        this = self.func("SUBSTR", this, position)
+        position_offset = f" + {position} - 1"
+
+    return self.func("STRPOS", this, substr, instance) + position_offset
 
 
 def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str:
@@ -689,9 +731,7 @@ def build_date_delta_with_interval(
         if expression and expression.is_string:
             expression = exp.Literal.number(expression.this)
 
-        return expression_class(
-            this=args[0], expression=expression, unit=exp.Literal.string(interval.text("unit"))
-        )
+        return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval))
 
     return _builder
 
@@ -710,18 +750,14 @@ def date_add_interval_sql(
 ) -> t.Callable[[Generator, exp.Expression], str]:
     def func(self: Generator, expression: exp.Expression) -> str:
         this = self.sql(expression, "this")
-        unit = expression.args.get("unit")
-        unit = exp.var(unit.name.upper() if unit else "DAY")
-        interval = exp.Interval(this=expression.expression, unit=unit)
+        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
         return f"{data_type}_{kind}({this}, {self.sql(interval)})"
 
     return func
 
 
 def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
-    return self.func(
-        "DATE_TRUNC", exp.Literal.string(expression.text("unit").upper() or "DAY"), expression.this
-    )
+    return self.func("DATE_TRUNC", unit_to_str(expression), expression.this)
 
 
 def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
@@ -956,7 +992,7 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
         return self.func(
             name,
-            exp.var(expression.text("unit").upper() or "DAY"),
+            unit_to_var(expression),
             expression.expression,
             expression.this,
         )
@@ -964,6 +1000,24 @@ def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
     return _delta_sql
 
 
+def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
+    unit = expression.args.get("unit")
+
+    if isinstance(unit, exp.Placeholder):
+        return unit
+    if unit:
+        return exp.Literal.string(unit.name)
+    return exp.Literal.string(default) if default else None
+
+
+def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
+    unit = expression.args.get("unit")
+
+    if isinstance(unit, (exp.Var, exp.Placeholder)):
+        return unit
+    return exp.Var(this=default) if default else None
+
+
 def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
     trunc_curr_date = exp.func("date_trunc", "month", expression.this)
     plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
@@ -998,7 +1052,7 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
 
 def build_json_extract_path(
-    expr_type: t.Type[F], zero_based_indexing: bool = True
+    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
 ) -> t.Callable[[t.List], F]:
     def _builder(args: t.List) -> F:
         segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
@@ -1018,7 +1072,11 @@ def build_json_extract_path(
         # This is done to avoid failing in the expression validator due to the arg count
         del args[2:]
-        return expr_type(this=seq_get(args, 0), expression=exp.JSONPath(expressions=segments))
+        return expr_type(
+            this=seq_get(args, 0),
+            expression=exp.JSONPath(expressions=segments),
+            only_json_types=arrow_req_json_type,
+        )
 
     return _builder
 
@@ -1070,3 +1128,12 @@ def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
     unnest = exp.Unnest(expressions=[expression.this])
     filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
     return self.sql(exp.Array(expressions=[filtered]))
+
+
+def to_number_with_nls_param(self, expression: exp.ToNumber) -> str:
+    return self.func(
+        "TO_NUMBER",
+        expression.this,
+        expression.args.get("format"),
+        expression.args.get("nlsparam"),
+    )
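Editor's note: the new `unit_to_str`/`unit_to_var` helpers centralize the old `e.text("unit") or "DAY"` pattern. A minimal sketch of their behavior:

```python
from sqlglot import exp
from sqlglot.dialects.dialect import unit_to_str, unit_to_var

date_diff = exp.DateDiff(this=exp.column("a"), expression=exp.column("b"))

# String-literal unit, for DATE_DIFF('DAY', a, b) style dialects:
assert unit_to_str(date_diff).sql() == "'DAY'"  # falls back to the default unit

# Bare-var unit, for DATE_DIFF(DAY, a, b) style dialects:
assert unit_to_var(date_diff).sql() == "DAY"
```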
rename_func("REPEATED_COUNT"), exp.Create: preprocess([move_schema_columns_to_partitioned_by]), - exp.DateAdd: _date_add_sql("ADD"), + exp.DateAdd: date_add_sql("ADD"), exp.DateStrToDate: datestrtodate_sql, - exp.DateSub: _date_add_sql("SUB"), + exp.DateSub: date_add_sql("SUB"), exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.DATEINT_FORMAT}) AS INT)", exp.DiToDate: lambda self, diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index f74dc97..6a1d07a 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -26,6 +26,7 @@ from sqlglot.dialects.dialect import ( str_to_time_sql, timestamptrunc_sql, timestrtotime_sql, + unit_to_var, ) from sqlglot.helper import flatten, seq_get from sqlglot.tokens import TokenType @@ -33,15 +34,16 @@ from sqlglot.tokens import TokenType def _ts_or_ds_add_sql(self: DuckDB.Generator, expression: exp.TsOrDsAdd) -> str: this = self.sql(expression, "this") - unit = self.sql(expression, "unit").strip("'") or "DAY" - interval = self.sql(exp.Interval(this=expression.expression, unit=unit)) + interval = self.sql(exp.Interval(this=expression.expression, unit=unit_to_var(expression))) return f"CAST({this} AS {self.sql(expression.return_type)}) + {interval}" -def _date_delta_sql(self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub) -> str: +def _date_delta_sql( + self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub | exp.TimeAdd +) -> str: this = self.sql(expression, "this") - unit = self.sql(expression, "unit").strip("'") or "DAY" - op = "+" if isinstance(expression, exp.DateAdd) else "-" + unit = unit_to_var(expression) + op = "+" if isinstance(expression, (exp.DateAdd, exp.TimeAdd)) else "-" return f"{this} {op} {self.sql(exp.Interval(this=expression.expression, unit=unit))}" @@ -186,6 +188,11 @@ class DuckDB(Dialect): return super().to_json_path(path) class Tokenizer(tokens.Tokenizer): + HEREDOC_STRINGS = ["$"] + + HEREDOC_TAG_IS_IDENTIFIER = True + HEREDOC_STRING_ALTERNATIVE = TokenType.PARAMETER + KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "//": TokenType.DIV, @@ -199,6 +206,7 @@ class DuckDB(Dialect): "LOGICAL": TokenType.BOOLEAN, "ONLY": TokenType.ONLY, "PIVOT_WIDER": TokenType.PIVOT, + "POSITIONAL": TokenType.POSITIONAL, "SIGNED": TokenType.INT, "STRING": TokenType.VARCHAR, "UBIGINT": TokenType.UBIGINT, @@ -227,8 +235,8 @@ class DuckDB(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, "ARRAY_HAS": exp.ArrayContains.from_arg_list, - "ARRAY_SORT": exp.SortArray.from_arg_list, "ARRAY_REVERSE_SORT": _build_sort_array_desc, + "ARRAY_SORT": exp.SortArray.from_arg_list, "DATEDIFF": _build_date_diff, "DATE_DIFF": _build_date_diff, "DATE_TRUNC": date_trunc_to_time, @@ -285,6 +293,11 @@ class DuckDB(Dialect): FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy() FUNCTION_PARSERS.pop("DECODE") + NO_PAREN_FUNCTION_PARSERS = { + **parser.Parser.NO_PAREN_FUNCTION_PARSERS, + "MAP": lambda self: self._parse_map(), + } + TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - { TokenType.SEMI, TokenType.ANTI, @@ -299,6 +312,13 @@ class DuckDB(Dialect): ), } + def _parse_map(self) -> exp.ToMap | exp.Map: + if self._match(TokenType.L_BRACE, advance=False): + return self.expression(exp.ToMap, this=self._parse_bracket()) + + args = self._parse_wrapped_csv(self._parse_conjunction) + return self.expression(exp.Map, keys=seq_get(args, 0), values=seq_get(args, 1)) + def _parse_types( self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True ) -> t.Optional[exp.Expression]: @@ 
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index f74dc97..6a1d07a 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -26,6 +26,7 @@ from sqlglot.dialects.dialect import (
     str_to_time_sql,
     timestamptrunc_sql,
     timestrtotime_sql,
+    unit_to_var,
 )
 from sqlglot.helper import flatten, seq_get
 from sqlglot.tokens import TokenType
@@ -33,15 +34,16 @@ from sqlglot.tokens import TokenType
 
 def _ts_or_ds_add_sql(self: DuckDB.Generator, expression: exp.TsOrDsAdd) -> str:
     this = self.sql(expression, "this")
-    unit = self.sql(expression, "unit").strip("'") or "DAY"
-    interval = self.sql(exp.Interval(this=expression.expression, unit=unit))
+    interval = self.sql(exp.Interval(this=expression.expression, unit=unit_to_var(expression)))
     return f"CAST({this} AS {self.sql(expression.return_type)}) + {interval}"
 
 
-def _date_delta_sql(self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+def _date_delta_sql(
+    self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub | exp.TimeAdd
+) -> str:
     this = self.sql(expression, "this")
-    unit = self.sql(expression, "unit").strip("'") or "DAY"
-    op = "+" if isinstance(expression, exp.DateAdd) else "-"
+    unit = unit_to_var(expression)
+    op = "+" if isinstance(expression, (exp.DateAdd, exp.TimeAdd)) else "-"
     return f"{this} {op} {self.sql(exp.Interval(this=expression.expression, unit=unit))}"
 
@@ -186,6 +188,11 @@ class DuckDB(Dialect):
             return super().to_json_path(path)
 
     class Tokenizer(tokens.Tokenizer):
+        HEREDOC_STRINGS = ["$"]
+
+        HEREDOC_TAG_IS_IDENTIFIER = True
+        HEREDOC_STRING_ALTERNATIVE = TokenType.PARAMETER
+
         KEYWORDS = {
             **tokens.Tokenizer.KEYWORDS,
             "//": TokenType.DIV,
@@ -199,6 +206,7 @@ class DuckDB(Dialect):
             "LOGICAL": TokenType.BOOLEAN,
             "ONLY": TokenType.ONLY,
             "PIVOT_WIDER": TokenType.PIVOT,
+            "POSITIONAL": TokenType.POSITIONAL,
             "SIGNED": TokenType.INT,
             "STRING": TokenType.VARCHAR,
             "UBIGINT": TokenType.UBIGINT,
@@ -227,8 +235,8 @@ class DuckDB(Dialect):
         FUNCTIONS = {
             **parser.Parser.FUNCTIONS,
             "ARRAY_HAS": exp.ArrayContains.from_arg_list,
-            "ARRAY_SORT": exp.SortArray.from_arg_list,
             "ARRAY_REVERSE_SORT": _build_sort_array_desc,
+            "ARRAY_SORT": exp.SortArray.from_arg_list,
             "DATEDIFF": _build_date_diff,
             "DATE_DIFF": _build_date_diff,
             "DATE_TRUNC": date_trunc_to_time,
@@ -285,6 +293,11 @@ class DuckDB(Dialect):
         FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
         FUNCTION_PARSERS.pop("DECODE")
 
+        NO_PAREN_FUNCTION_PARSERS = {
+            **parser.Parser.NO_PAREN_FUNCTION_PARSERS,
+            "MAP": lambda self: self._parse_map(),
+        }
+
         TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
             TokenType.SEMI,
             TokenType.ANTI,
@@ -299,6 +312,13 @@ class DuckDB(Dialect):
             ),
         }
 
+        def _parse_map(self) -> exp.ToMap | exp.Map:
+            if self._match(TokenType.L_BRACE, advance=False):
+                return self.expression(exp.ToMap, this=self._parse_bracket())
+
+            args = self._parse_wrapped_csv(self._parse_conjunction)
+            return self.expression(exp.Map, keys=seq_get(args, 0), values=seq_get(args, 1))
+
         def _parse_types(
             self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
         ) -> t.Optional[exp.Expression]:
@@ -345,6 +365,7 @@ class DuckDB(Dialect):
         SUPPORTS_CREATE_TABLE_LIKE = False
         MULTI_ARG_DISTINCT = False
         CAN_IMPLEMENT_ARRAY_ANY = True
+        SUPPORTS_TO_NUMBER = False
 
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
@@ -425,6 +446,7 @@ class DuckDB(Dialect):
                 "EPOCH", self.func("STRPTIME", e.this, self.format_time(e))
             ),
             exp.Struct: _struct_sql,
+            exp.TimeAdd: _date_delta_sql,
             exp.Timestamp: no_timestamp_sql,
             exp.TimestampDiff: lambda self, e: self.func(
                 "DATE_DIFF", exp.Literal.string(e.unit), e.expression, e.this
@@ -478,7 +500,7 @@ class DuckDB(Dialect):
 
         STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"}
 
-        UNWRAPPED_INTERVAL_VALUES = (exp.Column, exp.Literal, exp.Paren)
+        UNWRAPPED_INTERVAL_VALUES = (exp.Literal, exp.Paren)
 
         # DuckDB doesn't generally support CREATE TABLE .. properties
         # https://duckdb.org/docs/sql/statements/create_table.html
@@ -569,3 +591,9 @@ class DuckDB(Dialect):
                 return rename_func("RANGE")(self, expression)
 
             return super().generateseries_sql(expression)
+
+        def bracket_sql(self, expression: exp.Bracket) -> str:
+            if isinstance(expression.this, exp.Array):
+                expression.this.replace(exp.paren(expression.this))
+
+            return super().bracket_sql(expression)
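Editor's note: with the `MAP` entry in `NO_PAREN_FUNCTION_PARSERS`, both DuckDB map forms should now parse; a hedged sketch:

```python
import sqlglot

# The brace form becomes exp.ToMap, the two-list form exp.Map (reprs illustrative).
to_map = sqlglot.parse_one("SELECT MAP {'a': 1}", read="duckdb")
two_lists = sqlglot.parse_one("SELECT MAP(['a'], [1])", read="duckdb")
print(repr(to_map.selects[0]), repr(two_lists.selects[0]))
```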
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 55a9254..cc7debb 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -319,7 +319,9 @@ class Hive(Dialect):
             "TO_DATE": build_formatted_time(exp.TsOrDsToDate, "hive"),
             "TO_JSON": exp.JSONFormat.from_arg_list,
             "UNBASE64": exp.FromBase64.from_arg_list,
-            "UNIX_TIMESTAMP": build_formatted_time(exp.StrToUnix, "hive", True),
+            "UNIX_TIMESTAMP": lambda args: build_formatted_time(exp.StrToUnix, "hive", True)(
+                args or [exp.CurrentTimestamp()]
+            ),
             "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
         }
 
@@ -431,6 +433,7 @@ class Hive(Dialect):
         NVL2_SUPPORTED = False
         LAST_DAY_SUPPORTS_DATE_PART = False
         JSON_PATH_SINGLE_QUOTE_ESCAPE = True
+        SUPPORTS_TO_NUMBER = False
 
         EXPRESSIONS_WITHOUT_NESTED_CTES = {
             exp.Insert,
@@ -472,7 +475,7 @@ class Hive(Dialect):
             exp.ArgMax: arg_max_or_min_no_count("MAX_BY"),
             exp.ArgMin: arg_max_or_min_no_count("MIN_BY"),
             exp.ArrayConcat: rename_func("CONCAT"),
-            exp.ArrayJoin: lambda self, e: self.func("CONCAT_WS", e.expression, e.this),
+            exp.ArrayToString: lambda self, e: self.func("CONCAT_WS", e.expression, e.this),
             exp.ArraySize: rename_func("SIZE"),
             exp.ArraySort: _array_sort_sql,
             exp.With: no_recursive_cte_sql,
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 6ebae1e..1d53346 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -23,6 +23,7 @@ from sqlglot.dialects.dialect import (
     build_date_delta_with_interval,
     rename_func,
     strposition_to_locate_sql,
+    unit_to_var,
 )
 from sqlglot.helper import seq_get
 from sqlglot.tokens import TokenType
@@ -109,14 +110,14 @@ def _trim_sql(self: MySQL.Generator, expression: exp.Trim) -> str:
     return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
 
 
-def _date_add_sql(
+def date_add_sql(
     kind: str,
-) -> t.Callable[[MySQL.Generator, exp.Expression], str]:
-    def func(self: MySQL.Generator, expression: exp.Expression) -> str:
-        this = self.sql(expression, "this")
-        unit = expression.text("unit").upper() or "DAY"
-        return (
-            f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
+) -> t.Callable[[generator.Generator, exp.Expression], str]:
+    def func(self: generator.Generator, expression: exp.Expression) -> str:
+        return self.func(
+            f"DATE_{kind}",
+            expression.this,
+            exp.Interval(this=expression.expression, unit=unit_to_var(expression)),
         )
 
     return func
@@ -291,6 +292,7 @@ class MySQL(Dialect):
             "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
             "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
             "INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
+            "FROM_UNIXTIME": build_formatted_time(exp.UnixToTime, "mysql"),
             "ISNULL": isnull_to_is_null,
             "LOCATE": locate_to_strposition,
             "MAKETIME": exp.TimeFromParts.from_arg_list,
@@ -319,11 +321,7 @@ class MySQL(Dialect):
         FUNCTION_PARSERS = {
             **parser.Parser.FUNCTION_PARSERS,
             "CHAR": lambda self: self._parse_chr(),
-            "GROUP_CONCAT": lambda self: self.expression(
-                exp.GroupConcat,
-                this=self._parse_lambda(),
-                separator=self._match(TokenType.SEPARATOR) and self._parse_field(),
-            ),
+            "GROUP_CONCAT": lambda self: self._parse_group_concat(),
             # https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values
             "VALUES": lambda self: self.expression(
                 exp.Anonymous, this="VALUES", expressions=[self._parse_id_var()]
@@ -412,6 +410,11 @@ class MySQL(Dialect):
             "SPATIAL": lambda self: self._parse_index_constraint(kind="SPATIAL"),
         }
 
+        ALTER_PARSERS = {
+            **parser.Parser.ALTER_PARSERS,
+            "MODIFY": lambda self: self._parse_alter_table_alter(),
+        }
+
         SCHEMA_UNNAMED_CONSTRAINTS = {
             *parser.Parser.SCHEMA_UNNAMED_CONSTRAINTS,
             "FULLTEXT",
@@ -458,7 +461,7 @@ class MySQL(Dialect):
             this = self._parse_id_var(any_token=False)
             index_type = self._match(TokenType.USING) and self._advance_any() and self._prev.text
-            schema = self._parse_schema()
+            expressions = self._parse_wrapped_csv(self._parse_ordered)
 
             options = []
             while True:
@@ -478,9 +481,6 @@ class MySQL(Dialect):
                 elif self._match_text_seq("ENGINE_ATTRIBUTE"):
                     self._match(TokenType.EQ)
                     opt = exp.IndexConstraintOption(engine_attr=self._parse_string())
-                elif self._match_text_seq("ENGINE_ATTRIBUTE"):
-                    self._match(TokenType.EQ)
-                    opt = exp.IndexConstraintOption(engine_attr=self._parse_string())
                 elif self._match_text_seq("SECONDARY_ENGINE_ATTRIBUTE"):
                     self._match(TokenType.EQ)
                     opt = exp.IndexConstraintOption(secondary_engine_attr=self._parse_string())
@@ -495,7 +495,7 @@ class MySQL(Dialect):
             return self.expression(
                 exp.IndexColumnConstraint,
                 this=this,
-                schema=schema,
+                expressions=expressions,
                 kind=kind,
                 index_type=index_type,
                 options=options,
@@ -617,6 +617,39 @@ class MySQL(Dialect):
 
             return self.expression(exp.Chr, **kwargs)
 
+        def _parse_group_concat(self) -> t.Optional[exp.Expression]:
+            def concat_exprs(
+                node: t.Optional[exp.Expression], exprs: t.List[exp.Expression]
+            ) -> exp.Expression:
+                if isinstance(node, exp.Distinct) and len(node.expressions) > 1:
+                    concat_exprs = [
+                        self.expression(exp.Concat, expressions=node.expressions, safe=True)
+                    ]
+                    node.set("expressions", concat_exprs)
+                    return node
+                if len(exprs) == 1:
+                    return exprs[0]
+                return self.expression(exp.Concat, expressions=args, safe=True)
+
+            args = self._parse_csv(self._parse_lambda)
+
+            if args:
+                order = args[-1] if isinstance(args[-1], exp.Order) else None
+
+                if order:
+                    # Order By is the last (or only) expression in the list and has consumed the 'expr' before it,
+                    # remove 'expr' from exp.Order and add it back to args
+                    args[-1] = order.this
+                    order.set("this", concat_exprs(order.this, args))
+
+                this = order or concat_exprs(args[0], args)
+            else:
+                this = None
+
+            separator = self._parse_field() if self._match(TokenType.SEPARATOR) else None
+
+            return self.expression(exp.GroupConcat, this=this, separator=separator)
+
     class Generator(generator.Generator):
         LOCKING_READS_SUPPORTED = True
         NULL_ORDERING_SUPPORTED = None
@@ -630,6 +663,7 @@ class MySQL(Dialect):
         JSON_TYPE_REQUIRED_FOR_EXTRACTION = True
         JSON_PATH_BRACKETED_KEY_SUPPORTED = False
         JSON_KEY_VALUE_PAIR_SEP = ","
+        SUPPORTS_TO_NUMBER = False
 
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
@@ -637,9 +671,9 @@ class MySQL(Dialect):
             exp.DateDiff: _remove_ts_or_ds_to_date(
                 lambda self, e: self.func("DATEDIFF", e.this, e.expression), ("this", "expression")
             ),
-            exp.DateAdd: _remove_ts_or_ds_to_date(_date_add_sql("ADD")),
+            exp.DateAdd: _remove_ts_or_ds_to_date(date_add_sql("ADD")),
             exp.DateStrToDate: datestrtodate_sql,
-            exp.DateSub: _remove_ts_or_ds_to_date(_date_add_sql("SUB")),
+            exp.DateSub: _remove_ts_or_ds_to_date(date_add_sql("SUB")),
             exp.DateTrunc: _date_trunc_sql,
             exp.Day: _remove_ts_or_ds_to_date(),
             exp.DayOfMonth: _remove_ts_or_ds_to_date(rename_func("DAYOFMONTH")),
@@ -672,7 +706,7 @@ class MySQL(Dialect):
             exp.TimeFromParts: rename_func("MAKETIME"),
             exp.TimestampAdd: date_add_interval_sql("DATE", "ADD"),
             exp.TimestampDiff: lambda self, e: self.func(
-                "TIMESTAMPDIFF", e.text("unit"), e.expression, e.this
+                "TIMESTAMPDIFF", unit_to_var(e), e.expression, e.this
             ),
             exp.TimestampSub: date_add_interval_sql("DATE", "SUB"),
             exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
@@ -682,9 +716,10 @@ class MySQL(Dialect):
             ),
             exp.Trim: _trim_sql,
             exp.TryCast: no_trycast_sql,
-            exp.TsOrDsAdd: _date_add_sql("ADD"),
+            exp.TsOrDsAdd: date_add_sql("ADD"),
             exp.TsOrDsDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression),
             exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
+            exp.UnixToTime: lambda self, e: self.func("FROM_UNIXTIME", e.this, self.format_time(e)),
             exp.Week: _remove_ts_or_ds_to_date(),
             exp.WeekOfYear: _remove_ts_or_ds_to_date(rename_func("WEEKOFYEAR")),
             exp.Year: _remove_ts_or_ds_to_date(),
@@ -751,11 +786,6 @@ class MySQL(Dialect):
             result = f"{result} UNSIGNED"
             return result
 
-        def xor_sql(self, expression: exp.Xor) -> str:
-            if expression.expressions:
-                return self.expressions(expression, sep=" XOR ")
-            return super().xor_sql(expression)
-
         def jsonarraycontains_sql(self, expression: exp.JSONArrayContains) -> str:
             return f"{self.sql(expression, 'this')} MEMBER OF({self.sql(expression, 'expression')})"
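Editor's note: the rewritten `_parse_group_concat` keeps `ORDER BY` and `SEPARATOR` intact and folds multi-expression arguments into a `CONCAT`. A hedged example (output illustrative):

```python
import sqlglot

sql = "SELECT GROUP_CONCAT(DISTINCT a, b ORDER BY c SEPARATOR ',') FROM t"
# Multi-arg DISTINCT is concatenated first, then the whole aggregate can be
# transpiled, e.g. to Postgres STRING_AGG with the ORDER BY inside the call.
print(sqlglot.transpile(sql, read="mysql", write="postgres")[0])
```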
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index bccdad0..e038400 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import (
     build_formatted_time,
     no_ilike_sql,
     rename_func,
+    to_number_with_nls_param,
     trim_sql,
 )
 from sqlglot.helper import seq_get
@@ -246,6 +247,7 @@ class Oracle(Dialect):
             exp.TemporaryProperty: lambda _, e: f"{e.name or 'GLOBAL'} TEMPORARY",
             exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)),
             exp.ToChar: lambda self, e: self.function_fallback_sql(e),
+            exp.ToNumber: to_number_with_nls_param,
             exp.Trim: trim_sql,
             exp.UnixToTime: lambda self,
             e: f"TO_DATE('1970-01-01', 'YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index b53ae07..11398ed 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -278,6 +278,7 @@ class Postgres(Dialect):
             "REVOKE": TokenType.COMMAND,
             "SERIAL": TokenType.SERIAL,
             "SMALLSERIAL": TokenType.SMALLSERIAL,
+            "NAME": TokenType.NAME,
             "TEMP": TokenType.TEMPORARY,
             "CSTRING": TokenType.PSEUDO_TYPE,
             "OID": TokenType.OBJECT_IDENTIFIER,
@@ -356,6 +357,16 @@ class Postgres(Dialect):
 
         JSON_ARROWS_REQUIRE_JSON_TYPE = True
 
+        COLUMN_OPERATORS = {
+            **parser.Parser.COLUMN_OPERATORS,
+            TokenType.ARROW: lambda self, this, path: build_json_extract_path(
+                exp.JSONExtract, arrow_req_json_type=self.JSON_ARROWS_REQUIRE_JSON_TYPE
+            )([this, path]),
+            TokenType.DARROW: lambda self, this, path: build_json_extract_path(
+                exp.JSONExtractScalar, arrow_req_json_type=self.JSON_ARROWS_REQUIRE_JSON_TYPE
+            )([this, path]),
+        }
+
         def _parse_operator(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
             while True:
                 if not self._match(TokenType.L_PAREN):
@@ -484,6 +495,7 @@ class Postgres(Dialect):
                 ]
             ),
             exp.StrPosition: str_position_sql,
+            exp.StrToDate: lambda self, e: self.func("TO_DATE", e.this, self.format_time(e)),
             exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
             exp.StructExtract: struct_extract_sql,
             exp.Substring: _substring_sql,
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 3649bd2..25bba96 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -22,10 +22,13 @@ from sqlglot.dialects.dialect import (
     rename_func,
     right_to_substring_sql,
     struct_extract_sql,
+    str_position_sql,
     timestamptrunc_sql,
     timestrtotime_sql,
     ts_or_ds_add_cast,
+    unit_to_str,
 )
+from sqlglot.dialects.hive import Hive
 from sqlglot.dialects.mysql import MySQL
 from sqlglot.helper import apply_index_offset, seq_get
 from sqlglot.tokens import TokenType
@@ -93,14 +96,14 @@ def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate) -> str:
 
 def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str:
     expression = ts_or_ds_add_cast(expression)
-    unit = exp.Literal.string(expression.text("unit") or "DAY")
+    unit = unit_to_str(expression)
     return self.func("DATE_ADD", unit, expression.expression, expression.this)
 
 
 def _ts_or_ds_diff_sql(self: Presto.Generator, expression: exp.TsOrDsDiff) -> str:
     this = exp.cast(expression.this, "TIMESTAMP")
     expr = exp.cast(expression.expression, "TIMESTAMP")
-    unit = exp.Literal.string(expression.text("unit") or "DAY")
+    unit = unit_to_str(expression)
     return self.func("DATE_DIFF", unit, expr, this)
 
@@ -196,6 +199,7 @@ class Presto(Dialect):
     SUPPORTS_SEMI_ANTI_JOIN = False
     TYPED_DIVISION = True
     TABLESAMPLE_SIZE_IS_PERCENT = True
+    LOG_BASE_FIRST: t.Optional[bool] = None
 
     # https://github.com/trinodb/trino/issues/17
     # https://github.com/trinodb/trino/issues/12289
@@ -289,6 +293,7 @@ class Presto(Dialect):
         SUPPORTS_SINGLE_ARG_CONCAT = False
         LIKE_PROPERTY_INSIDE_SCHEMA = True
         MULTI_ARG_DISTINCT = False
+        SUPPORTS_TO_NUMBER = False
 
         PROPERTIES_LOCATION = {
             **generator.Generator.PROPERTIES_LOCATION,
@@ -323,6 +328,7 @@ class Presto(Dialect):
             exp.ArrayConcat: rename_func("CONCAT"),
             exp.ArrayContains: rename_func("CONTAINS"),
             exp.ArraySize: rename_func("CARDINALITY"),
+            exp.ArrayToString: rename_func("ARRAY_JOIN"),
             exp.ArrayUniqueAgg: rename_func("SET_AGG"),
             exp.AtTimeZone: rename_func("AT_TIMEZONE"),
             exp.BitwiseAnd: lambda self, e: self.func("BITWISE_AND", e.this, e.expression),
@@ -339,19 +345,19 @@ class Presto(Dialect):
             exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
             exp.DateAdd: lambda self, e: self.func(
                 "DATE_ADD",
-                exp.Literal.string(e.text("unit") or "DAY"),
+                unit_to_str(e),
                 _to_int(e.expression),
                 e.this,
             ),
             exp.DateDiff: lambda self, e: self.func(
-                "DATE_DIFF", exp.Literal.string(e.text("unit") or "DAY"), e.expression, e.this
+                "DATE_DIFF", unit_to_str(e), e.expression, e.this
             ),
             exp.DateStrToDate: datestrtodate_sql,
             exp.DateToDi: lambda self,
             e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)",
             exp.DateSub: lambda self, e: self.func(
                 "DATE_ADD",
-                exp.Literal.string(e.text("unit") or "DAY"),
+                unit_to_str(e),
                 _to_int(e.expression * -1),
                 e.this,
             ),
@@ -397,13 +403,10 @@ class Presto(Dialect):
                 ]
             ),
             exp.SortArray: _no_sort_array,
-            exp.StrPosition: rename_func("STRPOS"),
+            exp.StrPosition: lambda self, e: str_position_sql(self, e, generate_instance=True),
             exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
             exp.StrToMap: rename_func("SPLIT_TO_MAP"),
             exp.StrToTime: _str_to_time_sql,
-            exp.StrToUnix: lambda self, e: self.func(
-                "TO_UNIXTIME", self.func("DATE_PARSE", e.this, self.format_time(e))
-            ),
             exp.StructExtract: struct_extract_sql,
             exp.Table: transforms.preprocess([_unnest_sequence]),
             exp.Timestamp: no_timestamp_sql,
@@ -436,6 +439,22 @@ class Presto(Dialect):
             exp.Xor: bool_xor_sql,
         }
 
+        def strtounix_sql(self, expression: exp.StrToUnix) -> str:
+            # Since `TO_UNIXTIME` requires a `TIMESTAMP`, we need to parse the argument into one.
+            # To do this, we first try to `DATE_PARSE` it, but since this can fail when there's a
+            # timezone involved, we wrap it in a `TRY` call and use `PARSE_DATETIME` as a fallback,
+            # which seems to be using the same time mapping as Hive, as per:
+            # https://joda-time.sourceforge.net/apidocs/org/joda/time/format/DateTimeFormat.html
+            value_as_text = exp.cast(expression.this, "text")
+            parse_without_tz = self.func("DATE_PARSE", value_as_text, self.format_time(expression))
+            parse_with_tz = self.func(
+                "PARSE_DATETIME",
+                value_as_text,
+                self.format_time(expression, Hive.INVERSE_TIME_MAPPING, Hive.INVERSE_TIME_TRIE),
+            )
+            coalesced = self.func("COALESCE", self.func("TRY", parse_without_tz), parse_with_tz)
+            return self.func("TO_UNIXTIME", coalesced)
+
         def bracket_sql(self, expression: exp.Bracket) -> str:
             if expression.args.get("safe"):
                 return self.func(
@@ -481,8 +500,7 @@ class Presto(Dialect):
             return f"CAST(ROW({', '.join(values)}) AS ROW({', '.join(schema)}))"
 
         def interval_sql(self, expression: exp.Interval) -> str:
-            unit = self.sql(expression, "unit")
-            if expression.this and unit.startswith("WEEK"):
+            if expression.this and expression.text("unit").upper().startswith("WEEK"):
                 return f"({expression.this.name} * INTERVAL '7' DAY)"
             return super().interval_sql(expression)
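Editor's note: the effect of the new `strtounix_sql` is visible when transpiling Hive's `UNIX_TIMESTAMP(x, fmt)` to Presto, which now emits the `TRY(DATE_PARSE(...))`/`PARSE_DATETIME(...)` coalesce described in the comment above. A hedged example (output illustrative, not byte-exact):

```python
import sqlglot

sql = "SELECT UNIX_TIMESTAMP(ds, 'yyyy-MM-dd')"
# The Hive side builds exp.StrToUnix; Presto's strtounix_sql prints it as
# TO_UNIXTIME(COALESCE(TRY(DATE_PARSE(...)), PARSE_DATETIME(...))).
print(sqlglot.transpile(sql, read="hive", write="presto")[0])
```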
date_delta_sql("DATEDIFF"), + exp.UnixToTime: lambda self, + e: f"(TIMESTAMP 'epoch' + {self.sql(e.this)} * INTERVAL '1 SECOND')", } # Postgres maps exp.Pivot to no_pivot_sql, but Redshift support pivots diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 20fdfb7..73a9166 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -20,8 +20,7 @@ from sqlglot.dialects.dialect import ( timestrtotime_sql, var_map_sql, ) -from sqlglot.expressions import Literal -from sqlglot.helper import flatten, is_int, seq_get +from sqlglot.helper import flatten, is_float, is_int, seq_get from sqlglot.tokens import TokenType if t.TYPE_CHECKING: @@ -29,33 +28,35 @@ if t.TYPE_CHECKING: # from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html -def _build_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime, exp.TimeStrToTime]: - if len(args) == 2: - first_arg, second_arg = args - if second_arg.is_string: - # case: <string_expr> [ , <format> ] - return build_formatted_time(exp.StrToTime, "snowflake")(args) - return exp.UnixToTime(this=first_arg, scale=second_arg) +def _build_datetime( + name: str, kind: exp.DataType.Type, safe: bool = False +) -> t.Callable[[t.List], exp.Func]: + def _builder(args: t.List) -> exp.Func: + value = seq_get(args, 0) + + if isinstance(value, exp.Literal): + int_value = is_int(value.this) - from sqlglot.optimizer.simplify import simplify_literals + # Converts calls like `TO_TIME('01:02:03')` into casts + if len(args) == 1 and value.is_string and not int_value: + return exp.cast(value, kind) - # The first argument might be an expression like 40 * 365 * 86400, so we try to - # reduce it using `simplify_literals` first and then check if it's a Literal. - first_arg = seq_get(args, 0) - if not isinstance(simplify_literals(first_arg, root=True), Literal): - # case: <variant_expr> or other expressions such as columns - return exp.TimeStrToTime.from_arg_list(args) + # Handles `TO_TIMESTAMP(str, fmt)` and `TO_TIMESTAMP(num, scale)` as special + # cases so we can transpile them, since they're relatively common + if kind == exp.DataType.Type.TIMESTAMP: + if int_value: + return exp.UnixToTime(this=value, scale=seq_get(args, 1)) + if not is_float(value.this): + return build_formatted_time(exp.StrToTime, "snowflake")(args) - if first_arg.is_string: - if is_int(first_arg.this): - # case: <integer> - return exp.UnixToTime.from_arg_list(args) + if len(args) == 2 and kind == exp.DataType.Type.DATE: + formatted_exp = build_formatted_time(exp.TsOrDsToDate, "snowflake")(args) + formatted_exp.set("safe", safe) + return formatted_exp - # case: <date_expr> - return build_formatted_time(exp.StrToTime, "snowflake", default=True)(args) + return exp.Anonymous(this=name, expressions=args) - # case: <numeric_expr> - return exp.UnixToTime.from_arg_list(args) + return _builder def _build_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]: @@ -77,6 +78,17 @@ def _build_datediff(args: t.List) -> exp.DateDiff: ) +def _build_date_time_add(expr_type: t.Type[E]) -> t.Callable[[t.List], E]: + def _builder(args: t.List) -> E: + return expr_type( + this=seq_get(args, 2), + expression=seq_get(args, 1), + unit=_map_date_part(seq_get(args, 0)), + ) + + return _builder + + # https://docs.snowflake.com/en/sql-reference/functions/div0 def _build_if_from_div0(args: t.List) -> exp.If: cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0)) @@ -97,14 +109,6 @@ def _build_if_from_nullifzero(args: t.List) -> exp.If: 
return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0)) -def _datatype_sql(self: Snowflake.Generator, expression: exp.DataType) -> str: - if expression.is_type("array"): - return "ARRAY" - elif expression.is_type("map"): - return "OBJECT" - return self.datatype_sql(expression) - - def _regexpilike_sql(self: Snowflake.Generator, expression: exp.RegexpILike) -> str: flag = expression.text("flag") @@ -258,6 +262,25 @@ def _unqualify_unpivot_columns(expression: exp.Expression) -> exp.Expression: return expression +def _flatten_structured_types_unless_iceberg(expression: exp.Expression) -> exp.Expression: + assert isinstance(expression, exp.Create) + + def _flatten_structured_type(expression: exp.DataType) -> exp.DataType: + if expression.this in exp.DataType.NESTED_TYPES: + expression.set("expressions", None) + return expression + + props = expression.args.get("properties") + if isinstance(expression.this, exp.Schema) and not (props and props.find(exp.IcebergProperty)): + for schema_expression in expression.this.expressions: + if isinstance(schema_expression, exp.ColumnDef): + column_type = schema_expression.kind + if isinstance(column_type, exp.DataType): + column_type.transform(_flatten_structured_type, copy=False) + + return expression + + class Snowflake(Dialect): # https://docs.snowflake.com/en/sql-reference/identifiers-syntax NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE @@ -312,7 +335,13 @@ class Snowflake(Dialect): class Parser(parser.Parser): IDENTIFY_PIVOT_STRINGS = True + ID_VAR_TOKENS = { + *parser.Parser.ID_VAR_TOKENS, + TokenType.MATCH_CONDITION, + } + TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS | {TokenType.WINDOW} + TABLE_ALIAS_TOKENS.discard(TokenType.MATCH_CONDITION) FUNCTIONS = { **parser.Parser.FUNCTIONS, @@ -327,17 +356,13 @@ class Snowflake(Dialect): end=exp.Sub(this=seq_get(args, 1), expression=exp.Literal.number(1)), step=seq_get(args, 2), ), - "ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list, "BITXOR": binary_from_function(exp.BitwiseXor), "BIT_XOR": binary_from_function(exp.BitwiseXor), "BOOLXOR": binary_from_function(exp.Xor), "CONVERT_TIMEZONE": _build_convert_timezone, + "DATE": _build_datetime("DATE", exp.DataType.Type.DATE), "DATE_TRUNC": _date_trunc_to_time, - "DATEADD": lambda args: exp.DateAdd( - this=seq_get(args, 2), - expression=seq_get(args, 1), - unit=_map_date_part(seq_get(args, 0)), - ), + "DATEADD": _build_date_time_add(exp.DateAdd), "DATEDIFF": _build_datediff, "DIV0": _build_if_from_div0, "FLATTEN": exp.Explode.from_arg_list, @@ -349,17 +374,34 @@ class Snowflake(Dialect): this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1)) ), "LISTAGG": exp.GroupConcat.from_arg_list, + "MEDIAN": lambda args: exp.PercentileCont( + this=seq_get(args, 0), expression=exp.Literal.number(0.5) + ), "NULLIFZERO": _build_if_from_nullifzero, "OBJECT_CONSTRUCT": _build_object_construct, "REGEXP_REPLACE": _build_regexp_replace, "REGEXP_SUBSTR": exp.RegexpExtract.from_arg_list, "RLIKE": exp.RegexpLike.from_arg_list, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), + "TIMEADD": _build_date_time_add(exp.TimeAdd), "TIMEDIFF": _build_datediff, + "TIMESTAMPADD": _build_date_time_add(exp.DateAdd), "TIMESTAMPDIFF": _build_datediff, "TIMESTAMPFROMPARTS": _build_timestamp_from_parts, "TIMESTAMP_FROM_PARTS": _build_timestamp_from_parts, - "TO_TIMESTAMP": _build_to_timestamp, + "TRY_TO_DATE": _build_datetime("TRY_TO_DATE", exp.DataType.Type.DATE, safe=True), + "TO_DATE": _build_datetime("TO_DATE", 
exp.DataType.Type.DATE), + "TO_NUMBER": lambda args: exp.ToNumber( + this=seq_get(args, 0), + format=seq_get(args, 1), + precision=seq_get(args, 2), + scale=seq_get(args, 3), + ), + "TO_TIME": _build_datetime("TO_TIME", exp.DataType.Type.TIME), + "TO_TIMESTAMP": _build_datetime("TO_TIMESTAMP", exp.DataType.Type.TIMESTAMP), + "TO_TIMESTAMP_LTZ": _build_datetime("TO_TIMESTAMP_LTZ", exp.DataType.Type.TIMESTAMPLTZ), + "TO_TIMESTAMP_NTZ": _build_datetime("TO_TIMESTAMP_NTZ", exp.DataType.Type.TIMESTAMP), + "TO_TIMESTAMP_TZ": _build_datetime("TO_TIMESTAMP_TZ", exp.DataType.Type.TIMESTAMPTZ), "TO_VARCHAR": exp.ToChar.from_arg_list, "ZEROIFNULL": _build_if_from_zeroifnull, } @@ -377,7 +419,6 @@ class Snowflake(Dialect): **parser.Parser.RANGE_PARSERS, TokenType.LIKE_ANY: parser.binary_range_parser(exp.LikeAny), TokenType.ILIKE_ANY: parser.binary_range_parser(exp.ILikeAny), - TokenType.COLON: lambda self, this: self._parse_colon_get_path(this), } ALTER_PARSERS = { @@ -434,35 +475,35 @@ class Snowflake(Dialect): SCHEMA_KINDS = {"OBJECTS", "TABLES", "VIEWS", "SEQUENCES", "UNIQUE KEYS", "IMPORTED KEYS"} - def _parse_colon_get_path( - self: parser.Parser, this: t.Optional[exp.Expression] - ) -> t.Optional[exp.Expression]: - while True: - path = self._parse_bitwise() or self._parse_var(any_token=True) + def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: + this = super()._parse_column_ops(this) + + casts = [] + json_path = [] + + while self._match(TokenType.COLON): + path = super()._parse_column_ops(self._parse_field(any_token=True)) # The cast :: operator has a lower precedence than the extraction operator :, so # we rearrange the AST appropriately to avoid casting the 2nd argument of GET_PATH - if isinstance(path, exp.Cast): - target_type = path.to + while isinstance(path, exp.Cast): + casts.append(path.to) path = path.this - else: - target_type = None - if isinstance(path, exp.Expression): - path = exp.Literal.string(path.sql(dialect="snowflake")) + if path: + json_path.append(path.sql(dialect="snowflake", copy=False)) - # The extraction operator : is left-associative + if json_path: this = self.expression( - exp.JSONExtract, this=this, expression=self.dialect.to_json_path(path) + exp.JSONExtract, + this=this, + expression=self.dialect.to_json_path(exp.Literal.string(".".join(json_path))), ) - if target_type: - this = exp.cast(this, target_type) + while casts: + this = self.expression(exp.Cast, this=this, to=casts.pop()) - if not self._match(TokenType.COLON): - break - - return self._parse_range(this) + return this # https://docs.snowflake.com/en/sql-reference/functions/date_part.html # https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts @@ -663,6 +704,7 @@ class Snowflake(Dialect): "EXCLUDE": TokenType.EXCEPT, "ILIKE ANY": TokenType.ILIKE_ANY, "LIKE ANY": TokenType.LIKE_ANY, + "MATCH_CONDITION": TokenType.MATCH_CONDITION, "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE, "MINUS": TokenType.EXCEPT, "NCHAR VARYING": TokenType.VARCHAR, @@ -703,6 +745,7 @@ class Snowflake(Dialect): LIMIT_ONLY_LITERALS = True JSON_KEY_VALUE_PAIR_SEP = "," INSERT_OVERWRITE = " OVERWRITE INTO" + STRUCT_DELIMITER = ("(", ")") TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -711,15 +754,14 @@ class Snowflake(Dialect): exp.Array: inline_array_sql, exp.ArrayConcat: rename_func("ARRAY_CAT"), exp.ArrayContains: lambda self, e: self.func("ARRAY_CONTAINS", e.expression, e.this), - exp.ArrayJoin: rename_func("ARRAY_TO_STRING"), exp.AtTimeZone: 
lambda self, e: self.func( "CONVERT_TIMEZONE", e.args.get("zone"), e.this ), exp.BitwiseXor: rename_func("BITXOR"), + exp.Create: transforms.preprocess([_flatten_structured_types_unless_iceberg]), exp.DateAdd: date_delta_sql("DATEADD"), exp.DateDiff: date_delta_sql("DATEDIFF"), exp.DateStrToDate: datestrtodate_sql, - exp.DataType: _datatype_sql, exp.DayOfMonth: rename_func("DAYOFMONTH"), exp.DayOfWeek: rename_func("DAYOFWEEK"), exp.DayOfYear: rename_func("DAYOFYEAR"), @@ -769,6 +811,7 @@ class Snowflake(Dialect): ), exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)), exp.Stuff: rename_func("INSERT"), + exp.TimeAdd: date_delta_sql("TIMEADD"), exp.TimestampDiff: lambda self, e: self.func( "TIMESTAMPDIFF", e.unit, e.expression, e.this ), @@ -783,6 +826,9 @@ class Snowflake(Dialect): exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression), exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True), exp.TsOrDsDiff: date_delta_sql("DATEDIFF"), + exp.TsOrDsToDate: lambda self, e: self.func( + "TRY_TO_DATE" if e.args.get("safe") else "TO_DATE", e.this, self.format_time(e) + ), exp.UnixToTime: rename_func("TO_TIMESTAMP"), exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), exp.WeekOfYear: rename_func("WEEKOFYEAR"), @@ -797,6 +843,8 @@ class Snowflake(Dialect): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, + exp.DataType.Type.NESTED: "OBJECT", + exp.DataType.Type.STRUCT: "OBJECT", exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ", } @@ -811,6 +859,37 @@ class Snowflake(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + UNSUPPORTED_VALUES_EXPRESSIONS = { + exp.Struct, + } + + def values_sql(self, expression: exp.Values, values_as_table: bool = True) -> str: + if expression.find(*self.UNSUPPORTED_VALUES_EXPRESSIONS): + values_as_table = False + + return super().values_sql(expression, values_as_table=values_as_table) + + def datatype_sql(self, expression: exp.DataType) -> str: + expressions = expression.expressions + if ( + expressions + and expression.is_type(*exp.DataType.STRUCT_TYPES) + and any(isinstance(field_type, exp.DataType) for field_type in expressions) + ): + # The correct syntax is OBJECT [ (<key> <value_type [NOT NULL] [, ...]) ] + return "OBJECT" + + return super().datatype_sql(expression) + + def tonumber_sql(self, expression: exp.ToNumber) -> str: + return self.func( + "TO_NUMBER", + expression.this, + expression.args.get("format"), + expression.args.get("precision"), + expression.args.get("scale"), + ) + def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str: milli = expression.args.get("milli") if milli is not None: diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 20c0fce..88b5ddc 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -3,7 +3,7 @@ from __future__ import annotations import typing as t from sqlglot import exp -from sqlglot.dialects.dialect import rename_func +from sqlglot.dialects.dialect import rename_func, unit_to_var from sqlglot.dialects.hive import _build_with_ignore_nulls from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider from sqlglot.helper import seq_get @@ -78,6 +78,8 @@ class Spark(Spark2): return this class Generator(Spark2.Generator): + SUPPORTS_TO_NUMBER = True + TYPE_MAPPING = { **Spark2.Generator.TYPE_MAPPING, exp.DataType.Type.MONEY: "DECIMAL(15, 4)", @@ -100,7 +102,7 @@ class Spark(Spark2): e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in 
e.this.expressions], skip_first=True))}", exp.StartsWith: rename_func("STARTSWITH"), exp.TimestampAdd: lambda self, e: self.func( - "DATEADD", e.args.get("unit") or "DAY", e.expression, e.this + "DATEADD", unit_to_var(e), e.expression, e.this ), exp.TryCast: lambda self, e: ( self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e) @@ -117,11 +119,10 @@ class Spark(Spark2): return self.function_fallback_sql(expression) def datediff_sql(self, expression: exp.DateDiff) -> str: - unit = self.sql(expression, "unit") end = self.sql(expression, "this") start = self.sql(expression, "expression") - if unit: - return self.func("DATEDIFF", unit, start, end) + if expression.unit: + return self.func("DATEDIFF", unit_to_var(expression), start, end) return self.func("DATEDIFF", end, start) diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index 63eae6e..069916f 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import ( pivot_column_names, rename_func, trim_sql, + unit_to_str, ) from sqlglot.dialects.hive import Hive from sqlglot.helper import seq_get @@ -203,6 +204,7 @@ class Spark2(Hive): exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", + exp.ArrayToString: rename_func("ARRAY_JOIN"), exp.AtTimeZone: lambda self, e: self.func( "FROM_UTC_TIMESTAMP", e.this, e.args.get("zone") ), @@ -218,7 +220,7 @@ class Spark2(Hive): ] ), exp.DateFromParts: rename_func("MAKE_DATE"), - exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")), + exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, unit_to_str(e)), exp.DayOfMonth: rename_func("DAYOFMONTH"), exp.DayOfWeek: rename_func("DAYOFWEEK"), exp.DayOfYear: rename_func("DAYOFYEAR"), @@ -241,9 +243,7 @@ class Spark2(Hive): ), exp.StrToDate: _str_to_date, exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)), - exp.TimestampTrunc: lambda self, e: self.func( - "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this - ), + exp.TimestampTrunc: lambda self, e: self.func("DATE_TRUNC", unit_to_str(e), e.this), exp.Trim: trim_sql, exp.UnixToTime: _unix_to_time_sql, exp.VariancePop: rename_func("VAR_POP"), @@ -252,7 +252,6 @@ class Spark2(Hive): [transforms.remove_within_group_for_percentiles] ), } - TRANSFORMS.pop(exp.ArrayJoin) TRANSFORMS.pop(exp.ArraySort) TRANSFORMS.pop(exp.ILike) TRANSFORMS.pop(exp.Left) diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 2b17ff9..ef7d9aa 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -33,6 +33,14 @@ def _json_extract_sql(self: SQLite.Generator, expression: exp.JSONExtract) -> st return arrow_json_extract_sql(self, expression) +def _build_strftime(args: t.List) -> exp.Anonymous | exp.TimeToStr: + if len(args) == 1: + args.append(exp.CurrentTimestamp()) + if len(args) == 2: + return exp.TimeToStr(this=exp.TsOrDsToTimestamp(this=args[1]), format=args[0]) + return exp.Anonymous(this="STRFTIME", expressions=args) + + def _transform_create(expression: exp.Expression) -> exp.Expression: """Move primary key to a column and enforce auto_increment on primary keys.""" schema = expression.this @@ -82,6 +90,7 @@ class SQLite(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, "EDITDIST3": exp.Levenshtein.from_arg_list, + "STRFTIME": _build_strftime, } STRING_ALIASES = True @@ -93,6 +102,7 @@ class SQLite(Dialect): 
JSON_PATH_BRACKETED_KEY_SUPPORTED = False SUPPORTS_CREATE_TABLE_LIKE = False SUPPORTS_TABLE_ALIAS_COLUMNS = False + SUPPORTS_TO_NUMBER = False SUPPORTED_JSON_PATH_PARTS = { exp.JSONPathKey, @@ -151,7 +161,9 @@ class SQLite(Dialect): ), exp.TableSample: no_tablesample_sql, exp.TimeStrToTime: lambda self, e: self.sql(e, "this"), + exp.TimeToStr: lambda self, e: self.func("STRFTIME", e.args.get("format"), e.this), exp.TryCast: no_trycast_sql, + exp.TsOrDsToTimestamp: lambda self, e: self.sql(e, "this"), } # SQLite doesn't generally support CREATE TABLE .. properties diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py index 12ac600..5691f58 100644 --- a/sqlglot/dialects/starrocks.py +++ b/sqlglot/dialects/starrocks.py @@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import ( arrow_json_extract_sql, build_timestamp_trunc, rename_func, + unit_to_str, ) from sqlglot.dialects.mysql import MySQL from sqlglot.helper import seq_get @@ -39,15 +40,13 @@ class StarRocks(MySQL): **MySQL.Generator.TRANSFORMS, exp.ApproxDistinct: approx_count_distinct_sql, exp.DateDiff: lambda self, e: self.func( - "DATE_DIFF", exp.Literal.string(e.text("unit") or "DAY"), e.this, e.expression + "DATE_DIFF", unit_to_str(e), e.this, e.expression ), exp.JSONExtractScalar: arrow_json_extract_sql, exp.JSONExtract: arrow_json_extract_sql, exp.RegexpLike: rename_func("REGEXP"), exp.StrToUnix: lambda self, e: self.func("UNIX_TIMESTAMP", e.this, self.format_time(e)), - exp.TimestampTrunc: lambda self, e: self.func( - "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this - ), + exp.TimestampTrunc: lambda self, e: self.func("DATE_TRUNC", unit_to_str(e), e.this), exp.TimeStrToDate: rename_func("TO_DATE"), exp.UnixToStr: lambda self, e: self.func("FROM_UNIXTIME", e.this, self.format_time(e)), exp.UnixToTime: rename_func("FROM_UNIXTIME"), diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py index b736918..40feb67 100644 --- a/sqlglot/dialects/tableau.py +++ b/sqlglot/dialects/tableau.py @@ -5,6 +5,8 @@ from sqlglot.dialects.dialect import Dialect, rename_func class Tableau(Dialect): + LOG_BASE_FIRST = False + class Tokenizer(tokens.Tokenizer): IDENTIFIERS = [("[", "]")] QUOTES = ["'", '"'] diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 0663a1d..a65e10e 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -3,7 +3,13 @@ from __future__ import annotations import typing as t from sqlglot import exp, generator, parser, tokens, transforms -from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least, rename_func +from sqlglot.dialects.dialect import ( + Dialect, + max_or_greatest, + min_or_least, + rename_func, + to_number_with_nls_param, +) from sqlglot.tokens import TokenType @@ -206,6 +212,7 @@ class Teradata(Dialect): exp.StrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})", exp.ToChar: lambda self, e: self.function_fallback_sql(e), + exp.ToNumber: to_number_with_nls_param, exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}", } diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py index 1bbed67..457e2f0 100644 --- a/sqlglot/dialects/trino.py +++ b/sqlglot/dialects/trino.py @@ -7,6 +7,7 @@ from sqlglot.dialects.presto import Presto class Trino(Presto): SUPPORTS_USER_DEFINED_TYPES = False + LOG_BASE_FIRST = True class Generator(Presto.Generator): TRANSFORMS = { diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index b6f491f..8e06be6 100644 --- 
a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import ( NormalizationStrategy, any_value_to_max_sql, date_delta_sql, + datestrtodate_sql, generatedasidentitycolumnconstraint_sql, max_or_greatest, min_or_least, @@ -724,6 +725,7 @@ class TSQL(Dialect): TABLESAMPLE_SEED_KEYWORD = "REPEATABLE" SUPPORTS_SELECT_INTO = True JSON_PATH_BRACKETED_KEY_SUPPORTED = False + SUPPORTS_TO_NUMBER = False EXPRESSIONS_WITHOUT_NESTED_CTES = { exp.Delete, @@ -760,12 +762,14 @@ class TSQL(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, exp.AnyValue: any_value_to_max_sql, + exp.ArrayToString: rename_func("STRING_AGG"), exp.AutoIncrementColumnConstraint: lambda *_: "IDENTITY", exp.DateAdd: date_delta_sql("DATEADD"), exp.DateDiff: date_delta_sql("DATEDIFF"), exp.CTE: transforms.preprocess([qualify_derived_table_outputs]), exp.CurrentDate: rename_func("GETDATE"), exp.CurrentTimestamp: rename_func("GETDATE"), + exp.DateStrToDate: datestrtodate_sql, exp.Extract: rename_func("DATEPART"), exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql, exp.GroupConcat: _string_agg_sql, @@ -808,6 +812,22 @@ class TSQL(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def select_sql(self, expression: exp.Select) -> str: + if expression.args.get("offset"): + if not expression.args.get("order"): + # ORDER BY is required in order to use OFFSET in a query, so we use + # a noop order by, since we don't really care about the order. + # See: https://www.microsoftpressstore.com/articles/article.aspx?p=2314819 + expression.order_by(exp.select(exp.null()).subquery(), copy=False) + + limit = expression.args.get("limit") + if isinstance(limit, exp.Limit): + # TOP and OFFSET can't be combined, we need use FETCH instead of TOP + # we replace here because otherwise TOP would be generated in select_sql + limit.replace(exp.Fetch(direction="FIRST", count=limit.expression)) + + return super().select_sql(expression) + def convert_sql(self, expression: exp.Convert) -> str: name = "TRY_CONVERT" if expression.args.get("safe") else "CONVERT" return self.func( @@ -862,12 +882,12 @@ class TSQL(Dialect): return rename_func("DATETIMEFROMPARTS")(self, expression) - def set_operation(self, expression: exp.Union, op: str) -> str: + def set_operations(self, expression: exp.Union) -> str: limit = expression.args.get("limit") if limit: return self.sql(expression.limit(limit.pop(), copy=False)) - return super().set_operation(expression, op) + return super().set_operations(expression) def setitem_sql(self, expression: exp.SetItem) -> str: this = expression.this diff --git a/sqlglot/diff.py b/sqlglot/diff.py index bda9136..22c506a 100644 --- a/sqlglot/diff.py +++ b/sqlglot/diff.py @@ -103,7 +103,7 @@ def diff( ) -> t.Dict[int, exp.Expression]: return { id(old_node): new_node - for (old_node, _, _), (new_node, _, _) in zip(original.walk(), copy.walk()) + for old_node, new_node in zip(original.walk(), copy.walk()) if id(old_node) in matching_ids } @@ -158,14 +158,10 @@ class ChangeDistiller: self._source = source self._target = target self._source_index = { - id(n): n - for n, *_ in self._source.bfs() - if not isinstance(n, IGNORED_LEAF_EXPRESSION_TYPES) + id(n): n for n in self._source.bfs() if not isinstance(n, IGNORED_LEAF_EXPRESSION_TYPES) } self._target_index = { - id(n): n - for n, *_ in self._target.bfs() - if not isinstance(n, IGNORED_LEAF_EXPRESSION_TYPES) + id(n): n for n in self._target.bfs() if not isinstance(n, 
IGNORED_LEAF_EXPRESSION_TYPES) } self._unmatched_source_nodes = set(self._source_index) - set(pre_matched_nodes) self._unmatched_target_nodes = set(self._target_index) - set(pre_matched_nodes.values()) @@ -216,10 +212,10 @@ class ChangeDistiller: matching_set = leaves_matching_set.copy() ordered_unmatched_source_nodes = { - id(n): None for n, *_ in self._source.bfs() if id(n) in self._unmatched_source_nodes + id(n): None for n in self._source.bfs() if id(n) in self._unmatched_source_nodes } ordered_unmatched_target_nodes = { - id(n): None for n, *_ in self._target.bfs() if id(n) in self._unmatched_target_nodes + id(n): None for n in self._target.bfs() if id(n) in self._unmatched_target_nodes } for source_node_id in ordered_unmatched_source_nodes: @@ -322,7 +318,7 @@ class ChangeDistiller: def _get_leaves(expression: exp.Expression) -> t.Iterator[exp.Expression]: has_child_exprs = False - for _, node in expression.iter_expressions(): + for node in expression.iter_expressions(): if not isinstance(node, IGNORED_LEAF_EXPRESSION_TYPES): has_child_exprs = True yield from _get_leaves(node) diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py index c8f9148..29c0e68 100644 --- a/sqlglot/executor/__init__.py +++ b/sqlglot/executor/__init__.py @@ -10,11 +10,13 @@ import logging import time import typing as t +from sqlglot import exp from sqlglot.errors import ExecuteError from sqlglot.executor.python import PythonExecutor from sqlglot.executor.table import Table, ensure_tables from sqlglot.helper import dict_depth from sqlglot.optimizer import optimize +from sqlglot.optimizer.annotate_types import annotate_types from sqlglot.planner import Plan from sqlglot.schema import ensure_schema, flatten_schema, nested_get, nested_set @@ -26,15 +28,11 @@ if t.TYPE_CHECKING: from sqlglot.schema import Schema -PYTHON_TYPE_TO_SQLGLOT = { - "dict": "MAP", -} - - def execute( sql: str | Expression, schema: t.Optional[t.Dict | Schema] = None, read: DialectType = None, + dialect: DialectType = None, tables: t.Optional[t.Dict] = None, ) -> Table: """ @@ -48,11 +46,13 @@ def execute( 2. {db: {table: {col: type}}} 3. {catalog: {db: {table: {col: type}}}} read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql"). + dialect: the SQL dialect (alias for read). tables: additional tables to register. Returns: Simple columnar data structure. 
""" + read = read or dialect tables_ = ensure_tables(tables, dialect=read) if not schema: @@ -64,8 +64,9 @@ def execute( assert table is not None for column in table.columns: - py_type = type(table[0][column]).__name__ - nested_set(schema, [*keys, column], PYTHON_TYPE_TO_SQLGLOT.get(py_type) or py_type) + value = table[0][column] + column_type = annotate_types(exp.convert(value)).type or type(value).__name__ + nested_set(schema, [*keys, column], column_type) schema = ensure_schema(schema, dialect=read) diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py index 218a8e0..c51049b 100644 --- a/sqlglot/executor/env.py +++ b/sqlglot/executor/env.py @@ -106,6 +106,13 @@ def cast(this, to): return this if isinstance(this, str): return datetime.date.fromisoformat(this) + if to == exp.DataType.Type.TIME: + if isinstance(this, datetime.datetime): + return this.time() + if isinstance(this, datetime.time): + return this + if isinstance(this, str): + return datetime.time.fromisoformat(this) if to in (exp.DataType.Type.DATETIME, exp.DataType.Type.TIMESTAMP): if isinstance(this, datetime.datetime): return this @@ -139,7 +146,7 @@ def interval(this, unit): @null_if_any("this", "expression") -def arrayjoin(this, expression, null=None): +def arraytostring(this, expression, null=None): return expression.join(x for x in (x if x is not None else null for x in this) if x is not None) @@ -173,7 +180,7 @@ ENV = { "ABS": null_if_any(lambda this: abs(this)), "ADD": null_if_any(lambda e, this: e + this), "ARRAYANY": null_if_any(lambda arr, func: any(func(e) for e in arr)), - "ARRAYJOIN": arrayjoin, + "ARRAYTOSTRING": arraytostring, "BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high), "BITWISEAND": null_if_any(lambda this, e: this & e), "BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e), @@ -212,6 +219,7 @@ ENV = { "ORDERED": ordered, "POW": pow, "RIGHT": null_if_any(lambda this, e: this[-e:]), + "ROUND": null_if_any(lambda this, decimals=None, truncate=None: round(this, ndigits=decimals)), "STRPOSITION": str_position, "SUB": null_if_any(lambda e, this: e - this), "SUBSTRING": substring, @@ -225,10 +233,12 @@ ENV = { "CURRENTTIME": datetime.datetime.now, "CURRENTDATE": datetime.date.today, "STRFTIME": null_if_any(lambda fmt, arg: datetime.datetime.fromisoformat(arg).strftime(fmt)), + "STRTOTIME": null_if_any(lambda arg, format: datetime.datetime.strptime(arg, format)), "TRIM": null_if_any(lambda this, e=None: this.strip(e)), "STRUCT": lambda *args: { args[x]: args[x + 1] for x in range(0, len(args), 2) if (args[x + 1] is not None and args[x] is not None) }, + "UNIXTOTIME": null_if_any(lambda arg: datetime.datetime.utcfromtimestamp(arg)), } diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index a2b23d4..674ef78 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -157,7 +157,7 @@ class PythonExecutor: yield context.table.reader def join(self, step, context): - source = step.name + source = step.source_name source_table = context.tables[source] source_context = self.context({source: source_table}) @@ -398,7 +398,7 @@ def _lambda_sql(self, e: exp.Lambda) -> str: lambda n: ( exp.var(n.name) if isinstance(n, exp.Identifier) and n.name.lower() in names else n ) - ) + ).assert_is(exp.Lambda) return f"lambda {self.expressions(e, flat=True)}: {self.sql(e, 'this')}" diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 1a24875..e79c04b 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -39,6 +39,8 @@ if 
t.TYPE_CHECKING: from sqlglot._typing import E, Lit from sqlglot.dialects.dialect import DialectType + Q = t.TypeVar("Q", bound="Query") + class _Expression(type): def __new__(cls, clsname, bases, attrs): @@ -72,6 +74,7 @@ class Expression(metaclass=_Expression): parent: a reference to the parent expression (or None, in case of root expressions). arg_key: the arg key an expression is associated with, i.e. the name its parent expression uses to refer to it. + index: the index of an expression if it is inside of a list argument in its parent. comments: a list of comments that are associated with a given expression. This is used in order to preserve comments when transpiling SQL code. type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the @@ -91,12 +94,13 @@ class Expression(metaclass=_Expression): key = "expression" arg_types = {"this": True} - __slots__ = ("args", "parent", "arg_key", "comments", "_type", "_meta", "_hash") + __slots__ = ("args", "parent", "arg_key", "index", "comments", "_type", "_meta", "_hash") def __init__(self, **args: t.Any): self.args: t.Dict[str, t.Any] = args self.parent: t.Optional[Expression] = None self.arg_key: t.Optional[str] = None + self.index: t.Optional[int] = None self.comments: t.Optional[t.List[str]] = None self._type: t.Optional[DataType] = None self._meta: t.Optional[t.Dict[str, t.Any]] = None @@ -248,25 +252,44 @@ class Expression(metaclass=_Expression): return self._meta def __deepcopy__(self, memo): - copy = self.__class__(**deepcopy(self.args)) - if self.comments is not None: - copy.comments = deepcopy(self.comments) - - if self._type is not None: - copy._type = self._type.copy() - - if self._meta is not None: - copy._meta = deepcopy(self._meta) - - return copy + root = self.__class__() + stack = [(self, root)] + + while stack: + node, copy = stack.pop() + + if node.comments is not None: + copy.comments = deepcopy(node.comments) + if node._type is not None: + copy._type = deepcopy(node._type) + if node._meta is not None: + copy._meta = deepcopy(node._meta) + if node._hash is not None: + copy._hash = node._hash + + for k, vs in node.args.items(): + if hasattr(vs, "parent"): + stack.append((vs, vs.__class__())) + copy.set(k, stack[-1][-1]) + elif type(vs) is list: + copy.args[k] = [] + + for v in vs: + if hasattr(v, "parent"): + stack.append((v, v.__class__())) + copy.append(k, stack[-1][-1]) + else: + copy.append(k, v) + else: + copy.args[k] = vs + + return root def copy(self): """ Returns a deep copy of the expression. """ - new = deepcopy(self) - new.parent = self.parent - return new + return deepcopy(self) def add_comments(self, comments: t.Optional[t.List[str]]) -> None: if self.comments is None: @@ -289,35 +312,59 @@ class Expression(metaclass=_Expression): arg_key (str): name of the list expression arg value (Any): value to append to the list """ - if not isinstance(self.args.get(arg_key), list): + if type(self.args.get(arg_key)) is not list: self.args[arg_key] = [] - self.args[arg_key].append(value) self._set_parent(arg_key, value) + values = self.args[arg_key] + if hasattr(value, "parent"): + value.index = len(values) + values.append(value) - def set(self, arg_key: str, value: t.Any) -> None: + def set(self, arg_key: str, value: t.Any, index: t.Optional[int] = None) -> None: """ Sets arg_key to value. Args: arg_key: name of the expression arg. value: value to set the arg to. - """ - if value is None: + index: if the arg is a list, this specifies what position to add the value in it. 
+ """ + if index is not None: + expressions = self.args.get(arg_key) or [] + + if seq_get(expressions, index) is None: + return + if value is None: + expressions.pop(index) + for v in expressions[index:]: + v.index = v.index - 1 + return + + if isinstance(value, list): + expressions.pop(index) + expressions[index:index] = value + else: + expressions[index] = value + + value = expressions + elif value is None: self.args.pop(arg_key, None) return self.args[arg_key] = value - self._set_parent(arg_key, value) + self._set_parent(arg_key, value, index) - def _set_parent(self, arg_key: str, value: t.Any) -> None: + def _set_parent(self, arg_key: str, value: t.Any, index: t.Optional[int] = None) -> None: if hasattr(value, "parent"): value.parent = self value.arg_key = arg_key + value.index = index elif type(value) is list: - for v in value: + for index, v in enumerate(value): if hasattr(v, "parent"): v.parent = self v.arg_key = arg_key + v.index = index @property def depth(self) -> int: @@ -328,16 +375,17 @@ class Expression(metaclass=_Expression): return self.parent.depth + 1 return 0 - def iter_expressions(self) -> t.Iterator[t.Tuple[str, Expression]]: + def iter_expressions(self, reverse: bool = False) -> t.Iterator[Expression]: """Yields the key and expression for all arguments, exploding list args.""" - for k, vs in self.args.items(): + # remove tuple when python 3.7 is deprecated + for vs in reversed(tuple(self.args.values())) if reverse else self.args.values(): if type(vs) is list: - for v in vs: + for v in reversed(vs) if reverse else vs: if hasattr(v, "parent"): - yield k, v + yield v else: if hasattr(vs, "parent"): - yield k, vs + yield vs def find(self, *expression_types: t.Type[E], bfs: bool = True) -> t.Optional[E]: """ @@ -365,7 +413,7 @@ class Expression(metaclass=_Expression): Returns: The generator object. """ - for expression, *_ in self.walk(bfs=bfs): + for expression in self.walk(bfs=bfs): if isinstance(expression, expression_types): yield expression @@ -405,15 +453,17 @@ class Expression(metaclass=_Expression): expression = expression.parent return expression - def walk(self, bfs=True, prune=None): + def walk( + self, bfs: bool = True, prune: t.Optional[t.Callable[[Expression], bool]] = None + ) -> t.Iterator[Expression]: """ Returns a generator object which visits all nodes in this tree. Args: - bfs (bool): if set to True the BFS traversal order will be applied, + bfs: if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead. - prune ((node, parent, arg_key) -> bool): callable that returns True if - the generator should stop traversing this branch of the tree. + prune: callable that returns True if the generator should stop traversing + this branch of the tree. Returns: the generator object. @@ -423,7 +473,9 @@ class Expression(metaclass=_Expression): else: yield from self.dfs(prune=prune) - def dfs(self, parent=None, key=None, prune=None): + def dfs( + self, prune: t.Optional[t.Callable[[Expression], bool]] = None + ) -> t.Iterator[Expression]: """ Returns a generator object which visits all nodes in this tree in the DFS (Depth-first) order. @@ -431,15 +483,22 @@ class Expression(metaclass=_Expression): Returns: The generator object. 
""" - parent = parent or self.parent - yield self, parent, key - if prune and prune(self, parent, key): - return + stack = [self] + + while stack: + node = stack.pop() - for k, v in self.iter_expressions(): - yield from v.dfs(self, k, prune) + yield node - def bfs(self, prune=None): + if prune and prune(node): + continue + + for v in node.iter_expressions(reverse=True): + stack.append(v) + + def bfs( + self, prune: t.Optional[t.Callable[[Expression], bool]] = None + ) -> t.Iterator[Expression]: """ Returns a generator object which visits all nodes in this tree in the BFS (Breadth-first) order. @@ -447,17 +506,18 @@ class Expression(metaclass=_Expression): Returns: The generator object. """ - queue = deque([(self, self.parent, None)]) + queue = deque([self]) while queue: - item, parent, key = queue.popleft() + node = queue.popleft() - yield item, parent, key - if prune and prune(item, parent, key): + yield node + + if prune and prune(node): continue - for k, v in item.iter_expressions(): - queue.append((v, item, k)) + for v in node.iter_expressions(): + queue.append(v) def unnest(self): """ @@ -480,7 +540,7 @@ class Expression(metaclass=_Expression): """ Returns unnested operands as a tuple. """ - return tuple(arg.unnest() for _, arg in self.iter_expressions()) + return tuple(arg.unnest() for arg in self.iter_expressions()) def flatten(self, unnest=True): """ @@ -488,7 +548,7 @@ class Expression(metaclass=_Expression): A AND B AND C -> [A, B, C] """ - for node, _, _ in self.dfs(prune=lambda n, p, *_: p and type(n) is not self.__class__): + for node in self.dfs(prune=lambda n: n.parent and type(n) is not self.__class__): if type(node) is not self.__class__: yield node.unnest() if unnest and not isinstance(node, Subquery) else node @@ -520,32 +580,35 @@ class Expression(metaclass=_Expression): return Dialect.get_or_raise(dialect).generate(self, **opts) - def transform(self, fun, *args, copy=True, **kwargs): + def transform(self, fun: t.Callable, *args: t.Any, copy: bool = True, **kwargs) -> Expression: """ - Recursively visits all tree nodes (excluding already transformed ones) + Visits all tree nodes (excluding already transformed ones) and applies the given transformation function to each node. Args: - fun (function): a function which takes a node as an argument and returns a + fun: a function which takes a node as an argument and returns a new transformed node or the same node without modifications. If the function returns None, then the corresponding node will be removed from the syntax tree. - copy (bool): if set to True a new tree instance is constructed, otherwise the tree is + copy: if set to True a new tree instance is constructed, otherwise the tree is modified in place. Returns: The transformed tree. """ - node = self.copy() if copy else self - new_node = fun(node, *args, **kwargs) + root = None + new_node = None - if new_node is None or not isinstance(new_node, Expression): - return new_node - if new_node is not node: - new_node.parent = node.parent - return new_node + for node in (self.copy() if copy else self).dfs(prune=lambda n: n is not new_node): + parent, arg_key, index = node.parent, node.arg_key, node.index + new_node = fun(node, *args, **kwargs) + + if not root: + root = new_node + elif new_node is not node: + parent.set(arg_key, new_node, index) - replace_children(new_node, lambda child: child.transform(fun, *args, copy=False, **kwargs)) - return new_node + assert root + return root.assert_is(Expression) @t.overload def replace(self, expression: E) -> E: ... 
@@ -572,13 +635,26 @@ class Expression(metaclass=_Expression): Returns: The new expression or expressions. """ - if not self.parent: + parent = self.parent + + if not parent or parent is expression: return expression - parent = self.parent - self.parent = None + key = self.arg_key + value = parent.args.get(key) + + if type(expression) is list and isinstance(value, Expression): + # We are trying to replace an Expression with a list, so it's assumed that + # the intention was to really replace the parent of this expression. + value.parent.replace(expression) + else: + parent.set(key, expression, self.index) + + if expression is not self: + self.parent = None + self.arg_key = None + self.index = None - replace_children(parent, lambda child: expression if child is self else child) return expression def pop(self: E) -> E: @@ -816,6 +892,9 @@ class Expression(metaclass=_Expression): div.args["safe"] = safe return div + def asc(self, nulls_first: bool = True) -> Ordered: + return Ordered(this=self.copy(), nulls_first=nulls_first) + def desc(self, nulls_first: bool = False) -> Ordered: return Ordered(this=self.copy(), desc=True, nulls_first=nulls_first) @@ -983,13 +1062,13 @@ class Query(Expression): raise NotImplementedError("Query objects must implement `named_selects`") def select( - self, + self: Q, *expressions: t.Optional[ExpOrStr], append: bool = True, dialect: DialectType = None, copy: bool = True, **opts, - ) -> Query: + ) -> Q: """ Append to or set the SELECT expressions. @@ -1012,7 +1091,7 @@ class Query(Expression): raise NotImplementedError("Query objects must implement `select`") def with_( - self, + self: Q, alias: ExpOrStr, as_: ExpOrStr, recursive: t.Optional[bool] = None, @@ -1020,7 +1099,7 @@ class Query(Expression): dialect: DialectType = None, copy: bool = True, **opts, - ) -> Query: + ) -> Q: """ Append to or set the common table expressions. 
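
The replace rewrite above drops the old replace_children scan: a node now knows its own arg_key and index, so the parent can splice the replacement in directly, and a node sitting inside a list argument can even be replaced by a list of nodes. A short sketch of both paths (illustrative, not from the commit):

import sqlglot
from sqlglot import exp

tree = sqlglot.parse_one("SELECT a, b FROM t")

# Single-node replacement: the stored index keeps the projection in place.
first = tree.selects[0]
first.replace(exp.func("LOWER", first.copy()))

# List replacement: one list element expands into several at the same spot.
tree.selects[1].replace([exp.column("b"), exp.column("c")])

print(tree.sql())  # SELECT LOWER(a), b, c FROM t
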
@@ -1222,6 +1301,18 @@ class Create(DDL): return kind and kind.upper() +class SequenceProperties(Expression): + arg_types = { + "increment": False, + "minvalue": False, + "maxvalue": False, + "cache": False, + "start": False, + "owned": False, + "options": False, + } + + class TruncateTable(Expression): arg_types = { "expressions": True, @@ -1243,7 +1334,7 @@ class Clone(Expression): class Describe(Expression): - arg_types = {"this": True, "extended": False, "kind": False, "expressions": False} + arg_types = {"this": True, "style": False, "kind": False, "expressions": False} class Kill(Expression): @@ -1321,7 +1412,12 @@ class WithinGroup(Expression): # clickhouse supports scalar ctes # https://clickhouse.com/docs/en/sql-reference/statements/select/with class CTE(DerivedTable): - arg_types = {"this": True, "alias": True, "scalar": False} + arg_types = { + "this": True, + "alias": True, + "scalar": False, + "materialized": False, + } class TableAlias(Expression): @@ -1541,6 +1637,15 @@ class EncodeColumnConstraint(ColumnConstraintKind): pass +# https://www.postgresql.org/docs/current/sql-createtable.html#SQL-CREATETABLE-EXCLUDE +class ExcludeColumnConstraint(ColumnConstraintKind): + pass + + +class WithOperator(Expression): + arg_types = {"this": True, "op": True} + + class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind): # this: True -> ALWAYS, this: False -> BY DEFAULT arg_types = { @@ -1560,13 +1665,16 @@ class GeneratedAsRowColumnConstraint(ColumnConstraintKind): # https://dev.mysql.com/doc/refman/8.0/en/create-table.html +# https://github.com/ClickHouse/ClickHouse/blob/master/src/Parsers/ParserCreateQuery.h#L646 class IndexColumnConstraint(ColumnConstraintKind): arg_types = { "this": False, - "schema": True, + "expressions": False, "kind": False, "index_type": False, "options": False, + "expression": False, # Clickhouse + "granularity": False, } @@ -1605,7 +1713,7 @@ class TitleColumnConstraint(ColumnConstraintKind): class UniqueColumnConstraint(ColumnConstraintKind): - arg_types = {"this": False, "index_type": False} + arg_types = {"this": False, "index_type": False, "on_conflict": False} class UppercaseColumnConstraint(ColumnConstraintKind): @@ -1714,6 +1822,7 @@ class Drop(Expression): arg_types = { "this": False, "kind": False, + "expressions": False, "exists": False, "temporary": False, "materialized": False, @@ -1733,7 +1842,7 @@ class Check(Expression): # https://docs.snowflake.com/en/sql-reference/constructs/connect-by class Connect(Expression): - arg_types = {"start": False, "connect": True} + arg_types = {"start": False, "connect": True, "nocycle": False} class Prior(Expression): @@ -1815,20 +1924,30 @@ class Index(Expression): arg_types = { "this": False, "table": False, - "using": False, - "where": False, - "columns": False, "unique": False, "primary": False, "amp": False, # teradata + "params": False, + } + + +class IndexParameters(Expression): + arg_types = { + "using": False, "include": False, - "partition_by": False, # teradata + "columns": False, + "with_storage": False, + "partition_by": False, + "tablespace": False, + "where": False, } class Insert(DDL, DML): arg_types = { + "hint": False, "with": False, + "is_function": False, "this": True, "expression": False, "conflict": False, @@ -1883,8 +2002,8 @@ class OnConflict(Expression): arg_types = { "duplicate": False, "expressions": False, - "nothing": False, - "key": False, + "action": False, + "conflict_keys": False, "constraint": False, } @@ -1981,6 +2100,7 @@ class Join(Expression): "method": False, 
"global": False, "hint": False, + "match_condition": False, # Snowflake } @property @@ -2173,6 +2293,10 @@ class AutoRefreshProperty(Property): arg_types = {"this": True} +class BackupProperty(Property): + arg_types = {"this": True} + + class BlockCompressionProperty(Property): arg_types = { "autotemp": False, @@ -2253,6 +2377,14 @@ class FreespaceProperty(Property): arg_types = {"this": True, "percent": False} +class GlobalProperty(Property): + arg_types = {} + + +class IcebergProperty(Property): + arg_types = {} + + class InheritsProperty(Property): arg_types = {"expressions": True} @@ -2266,13 +2398,7 @@ class OutputModelProperty(Property): class IsolatedLoadingProperty(Property): - arg_types = { - "no": False, - "concurrent": False, - "for_all": False, - "for_insert": False, - "for_none": False, - } + arg_types = {"no": False, "concurrent": False, "target": False} class JournalProperty(Property): @@ -2436,6 +2562,10 @@ class SetProperty(Property): arg_types = {"multi": True} +class SharingProperty(Property): + arg_types = {"this": False} + + class SetConfigProperty(Property): arg_types = {"this": True} @@ -2472,6 +2602,15 @@ class TransientProperty(Property): arg_types = {"this": False} +class UnloggedProperty(Property): + arg_types = {} + + +# https://learn.microsoft.com/en-us/sql/t-sql/statements/create-view-transact-sql?view=sql-server-ver16 +class ViewAttributeProperty(Property): + arg_types = {"this": True} + + class VolatileProperty(Property): arg_types = {"this": False} @@ -3630,6 +3769,10 @@ class SessionParameter(Condition): class Placeholder(Condition): arg_types = {"this": False, "kind": False} + @property + def name(self) -> str: + return self.this or "?" + class Null(Condition): arg_types: t.Dict[str, t.Any] = {} @@ -3714,6 +3857,7 @@ class DataType(Expression): MEDIUMINT = auto() MEDIUMTEXT = auto() MONEY = auto() + NAME = auto() NCHAR = auto() NESTED = auto() NULL = auto() @@ -3764,47 +3908,85 @@ class DataType(Expression): XML = auto() YEAR = auto() + STRUCT_TYPES = { + Type.NESTED, + Type.OBJECT, + Type.STRUCT, + } + + NESTED_TYPES = { + *STRUCT_TYPES, + Type.ARRAY, + Type.MAP, + } + TEXT_TYPES = { Type.CHAR, Type.NCHAR, - Type.VARCHAR, Type.NVARCHAR, Type.TEXT, + Type.VARCHAR, + Type.NAME, } - INTEGER_TYPES = { - Type.INT, - Type.TINYINT, - Type.SMALLINT, + SIGNED_INTEGER_TYPES = { Type.BIGINT, + Type.INT, Type.INT128, Type.INT256, + Type.MEDIUMINT, + Type.SMALLINT, + Type.TINYINT, + } + + UNSIGNED_INTEGER_TYPES = { + Type.UBIGINT, + Type.UINT, + Type.UINT128, + Type.UINT256, + Type.UMEDIUMINT, + Type.USMALLINT, + Type.UTINYINT, + } + + INTEGER_TYPES = { + *SIGNED_INTEGER_TYPES, + *UNSIGNED_INTEGER_TYPES, Type.BIT, } FLOAT_TYPES = { - Type.FLOAT, Type.DOUBLE, + Type.FLOAT, + } + + REAL_TYPES = { + *FLOAT_TYPES, + Type.BIGDECIMAL, + Type.DECIMAL, + Type.MONEY, + Type.SMALLMONEY, + Type.UDECIMAL, } NUMERIC_TYPES = { *INTEGER_TYPES, - *FLOAT_TYPES, + *REAL_TYPES, } TEMPORAL_TYPES = { + Type.DATE, + Type.DATE32, + Type.DATETIME, + Type.DATETIME64, Type.TIME, - Type.TIMETZ, Type.TIMESTAMP, - Type.TIMESTAMPTZ, Type.TIMESTAMPLTZ, - Type.TIMESTAMP_S, + Type.TIMESTAMPTZ, Type.TIMESTAMP_MS, Type.TIMESTAMP_NS, - Type.DATE, - Type.DATE32, - Type.DATETIME, - Type.DATETIME64, + Type.TIMESTAMP_S, + Type.TIMETZ, } @classmethod @@ -4163,8 +4345,6 @@ class Not(Unary): class Paren(Unary): - arg_types = {"this": True, "with": False} - @property def output_name(self) -> str: return self.this.name @@ -4277,7 +4457,7 @@ class TimeUnit(Expression): super().__init__(**args) @property - def 
unit(self) -> t.Optional[Var]: + def unit(self) -> t.Optional[Var | IntervalSpan]: return self.args.get("unit") @@ -4451,6 +4631,18 @@ class ToChar(Func): arg_types = {"this": True, "format": False, "nlsparam": False} +# https://docs.snowflake.com/en/sql-reference/functions/to_decimal +# https://docs.oracle.com/en/database/oracle/oracle-database/23/sqlrf/TO_NUMBER.html +class ToNumber(Func): + arg_types = { + "this": True, + "format": False, + "nlsparam": False, + "precision": False, + "scale": False, + } + + # https://learn.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-ver16#syntax class Convert(Func): arg_types = {"this": True, "expression": True, "style": False} @@ -4496,8 +4688,9 @@ class ArrayFilter(Func): _sql_names = ["FILTER", "ARRAY_FILTER"] -class ArrayJoin(Func): +class ArrayToString(Func): arg_types = {"this": True, "expression": True, "null": False} + _sql_names = ["ARRAY_TO_STRING", "ARRAY_JOIN"] class ArrayOverlaps(Binary, Func): @@ -4580,7 +4773,13 @@ class Case(Func): class Cast(Func): - arg_types = {"this": True, "to": True, "format": False, "safe": False} + arg_types = { + "this": True, + "to": True, + "format": False, + "safe": False, + "action": False, + } @property def name(self) -> str: @@ -4889,6 +5088,10 @@ class ToBase64(Func): pass +class GenerateDateArray(Func): + arg_types = {"start": True, "end": True, "interval": False} + + class Greatest(Func): arg_types = {"this": True, "expressions": False} is_var_len_args = True @@ -5142,14 +5345,6 @@ class Log(Func): arg_types = {"this": True, "expression": False} -class Log2(Func): - pass - - -class Log10(Func): - pass - - class LogicalOr(AggFunc): _sql_names = ["LOGICAL_OR", "BOOL_OR", "BOOLOR_AGG"] @@ -5176,6 +5371,11 @@ class Map(Func): return values.expressions if values else [] +# Represents the MAP {...} syntax in DuckDB - basically convert a struct to a MAP +class ToMap(Func): + pass + + class MapFromEntries(Func): pass @@ -5501,13 +5701,17 @@ class TsOrDsToDateStr(Func): class TsOrDsToDate(Func): - arg_types = {"this": True, "format": False} + arg_types = {"this": True, "format": False, "safe": False} class TsOrDsToTime(Func): pass +class TsOrDsToTimestamp(Func): + pass + + class TsOrDiToDi(Func): pass @@ -5528,7 +5732,14 @@ class UnixToStr(Func): # https://prestodb.io/docs/current/functions/datetime.html # presto has weird zone/hours/minutes class UnixToTime(Func): - arg_types = {"this": True, "scale": False, "zone": False, "hours": False, "minutes": False} + arg_types = { + "this": True, + "scale": False, + "zone": False, + "hours": False, + "minutes": False, + "format": False, + } SECONDS = Literal.number(0) DECIS = Literal.number(1) @@ -5565,6 +5776,10 @@ class Upper(Func): _sql_names = ["UPPER", "UCASE"] +class Corr(Binary, AggFunc): + pass + + class Variance(AggFunc): _sql_names = ["VARIANCE", "VARIANCE_SAMP", "VAR_SAMP"] @@ -5573,6 +5788,14 @@ class VariancePop(AggFunc): _sql_names = ["VARIANCE_POP", "VAR_POP"] +class CovarSamp(Binary, AggFunc): + pass + + +class CovarPop(Binary, AggFunc): + pass + + class Week(Func): arg_types = {"this": True, "mode": False} @@ -6516,7 +6739,7 @@ def subquery( **opts, ) -> Select: """ - Build a subquery expression. + Build a subquery expression that's selected from. Example: >>> subquery('select x from tbl', 'bar').select('x').sql() @@ -6766,7 +6989,7 @@ def convert(value: t.Any, copy: bool = False) -> Expression: copy: Whether to copy `value` (only applies to Expressions and collections). 
Returns: - Expression: the equivalent expression object. + The equivalent expression object. """ if isinstance(value, Expression): return maybe_copy(value, copy) @@ -6778,15 +7001,28 @@ def convert(value: t.Any, copy: bool = False) -> Expression: return null() if isinstance(value, numbers.Number): return Literal.number(value) + if isinstance(value, bytes): + return HexString(this=value.hex()) if isinstance(value, datetime.datetime): datetime_literal = Literal.string( - (value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat() + (value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat( + sep=" " + ) ) return TimeStrToTime(this=datetime_literal) if isinstance(value, datetime.date): date_literal = Literal.string(value.strftime("%Y-%m-%d")) return DateStrToDate(this=date_literal) if isinstance(value, tuple): + if hasattr(value, "_fields"): + return Struct( + expressions=[ + PropertyEQ( + this=to_identifier(k), expression=convert(getattr(value, k), copy=copy) + ) + for k in value._fields + ] + ) return Tuple(expressions=[convert(v, copy=copy) for v in value]) if isinstance(value, list): return Array(expressions=[convert(v, copy=copy) for v in value]) @@ -6795,6 +7031,13 @@ def convert(value: t.Any, copy: bool = False) -> Expression: keys=Array(expressions=[convert(k, copy=copy) for k in value]), values=Array(expressions=[convert(v, copy=copy) for v in value.values()]), ) + if hasattr(value, "__dict__"): + return Struct( + expressions=[ + PropertyEQ(this=to_identifier(k), expression=convert(v, copy=copy)) + for k, v in value.__dict__.items() + ] + ) raise ValueError(f"Cannot convert {value}") @@ -6802,7 +7045,7 @@ def replace_children(expression: Expression, fun: t.Callable, *args, **kwargs) - """ Replace children of an expression with the result of a lambda fun(child) -> exp. """ - for k, v in expression.args.items(): + for k, v in tuple(expression.args.items()): is_list_arg = type(v) is list child_nodes = v if is_list_arg else [v] @@ -6812,12 +7055,36 @@ def replace_children(expression: Expression, fun: t.Callable, *args, **kwargs) - if isinstance(cn, Expression): for child_node in ensure_collection(fun(cn, *args, **kwargs)): new_child_nodes.append(child_node) - child_node.parent = expression - child_node.arg_key = k else: new_child_nodes.append(cn) - expression.args[k] = new_child_nodes if is_list_arg else seq_get(new_child_nodes, 0) + expression.set(k, new_child_nodes if is_list_arg else seq_get(new_child_nodes, 0)) + + +def replace_tree( + expression: Expression, + fun: t.Callable, + prune: t.Optional[t.Callable[[Expression], bool]] = None, +) -> Expression: + """ + Replace an entire tree with the result of function calls on each node. + + This will be traversed in reverse dfs, so leaves first. + If new nodes are created as a result of function calls, they will also be traversed. 
+ """ + stack = list(expression.dfs(prune=prune)) + + while stack: + node = stack.pop() + new_node = fun(node) + + if new_node is not node: + node.replace(new_node) + + if isinstance(new_node, Expression): + stack.append(new_node) + + return new_node def column_table_names(expression: Expression, exclude: str = "") -> t.Set[str]: @@ -6936,7 +7203,7 @@ def replace_tables( return table return node - return expression.transform(_replace_tables, copy=copy) + return expression.transform(_replace_tables, copy=copy) # type: ignore def replace_placeholders(expression: Expression, *args, **kwargs) -> Expression: @@ -6961,8 +7228,8 @@ def replace_placeholders(expression: Expression, *args, **kwargs) -> Expression: def _replace_placeholders(node: Expression, args, **kwargs) -> Expression: if isinstance(node, Placeholder): - if node.name: - new_name = kwargs.get(node.name) + if node.this: + new_name = kwargs.get(node.this) if new_name is not None: return convert(new_name) else: @@ -7193,3 +7460,15 @@ def null() -> Null: Returns a Null expression. """ return Null() + + +NONNULL_CONSTANTS = ( + Literal, + Boolean, +) + +CONSTANTS = ( + Literal, + Boolean, + Null, +) diff --git a/sqlglot/generator.py b/sqlglot/generator.py index e6f5c4b..76d9b5d 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -46,9 +46,11 @@ class Generator(metaclass=_Generator): 'safe': Only quote identifiers that are case insensitive. normalize: Whether to normalize identifiers to lowercase. Default: False. - pad: The pad size in a formatted string. + pad: The pad size in a formatted string. For example, this affects the indentation of + a projection in a query, relative to its nesting level. Default: 2. - indent: The indentation size in a formatted string. + indent: The indentation size in a formatted string. For example, this affects the + indentation of subqueries and filters under a `WHERE` clause. Default: 2. normalize_functions: How to normalize function names. Possible values are: "upper" or True (default): Convert names to uppercase. 
@@ -73,6 +75,7 @@ class Generator(metaclass=_Generator): TRANSFORMS: t.Dict[t.Type[exp.Expression], t.Callable[..., str]] = { **JSON_PATH_PART_TRANSFORMS, exp.AutoRefreshProperty: lambda self, e: f"AUTO REFRESH {self.sql(e, 'this')}", + exp.BackupProperty: lambda self, e: f"BACKUP {self.sql(e, 'this')}", exp.CaseSpecificColumnConstraint: lambda _, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC", exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}", @@ -83,15 +86,15 @@ class Generator(metaclass=_Generator): exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}", exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}", exp.CopyGrantsProperty: lambda *_: "COPY GRANTS", - exp.DateAdd: lambda self, e: self.func( - "DATE_ADD", e.this, e.expression, exp.Literal.string(e.text("unit")) - ), exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}", exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}", exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}", + exp.ExcludeColumnConstraint: lambda self, e: f"EXCLUDE {self.sql(e, 'this').lstrip()}", exp.ExecuteAsProperty: lambda self, e: self.naked_property(e), exp.ExternalProperty: lambda *_: "EXTERNAL", + exp.GlobalProperty: lambda *_: "GLOBAL", exp.HeapProperty: lambda *_: "HEAP", + exp.IcebergProperty: lambda *_: "ICEBERG", exp.InheritsProperty: lambda self, e: f"INHERITS ({self.expressions(e, flat=True)})", exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}", exp.InputModelProperty: lambda self, e: f"INPUT{self.sql(e, 'this')}", @@ -123,6 +126,7 @@ class Generator(metaclass=_Generator): exp.SetConfigProperty: lambda self, e: self.sql(e, "this"), exp.SetProperty: lambda _, e: f"{'MULTI' if e.args.get('multi') else ''}SET", exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}", + exp.SharingProperty: lambda self, e: f"SHARING={self.sql(e, 'this')}", exp.SqlReadWriteProperty: lambda _, e: e.name, exp.SqlSecurityProperty: lambda _, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}", @@ -130,13 +134,17 @@ class Generator(metaclass=_Generator): exp.TemporaryProperty: lambda *_: "TEMPORARY", exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}", exp.Timestamp: lambda self, e: self.func("TIMESTAMP", e.this, e.expression), + exp.ToMap: lambda self, e: f"MAP {self.sql(e, 'this')}", exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}", exp.TransformModelProperty: lambda self, e: self.func("TRANSFORM", *e.expressions), exp.TransientProperty: lambda *_: "TRANSIENT", exp.UppercaseColumnConstraint: lambda *_: "UPPERCASE", + exp.UnloggedProperty: lambda *_: "UNLOGGED", exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]), + exp.ViewAttributeProperty: lambda self, e: f"WITH {self.sql(e, 'this')}", exp.VolatileProperty: lambda *_: "VOLATILE", exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}", + exp.WithOperator: lambda self, e: f"{self.sql(e, 'this')} WITH {self.sql(e, 'op')}", } # Whether null ordering is supported in order by @@ -321,6 +329,9 @@ class Generator(metaclass=_Generator): # Whether any(f(x) for x in array) can be implemented by this dialect CAN_IMPLEMENT_ARRAY_ANY = False + # Whether the function TO_NUMBER is supported + SUPPORTS_TO_NUMBER = True + TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", 
exp.DataType.Type.NVARCHAR: "VARCHAR", @@ -350,6 +361,18 @@ class Generator(metaclass=_Generator): "YEARS": "YEAR", } + AFTER_HAVING_MODIFIER_TRANSFORMS = { + "cluster": lambda self, e: self.sql(e, "cluster"), + "distribute": lambda self, e: self.sql(e, "distribute"), + "qualify": lambda self, e: self.sql(e, "qualify"), + "sort": lambda self, e: self.sql(e, "sort"), + "windows": lambda self, e: ( + self.seg("WINDOW ") + self.expressions(e, key="windows", flat=True) + if e.args.get("windows") + else "" + ), + } + TOKEN_MAPPING: t.Dict[TokenType, str] = {} STRUCT_DELIMITER = ("<", ">") @@ -361,6 +384,7 @@ class Generator(metaclass=_Generator): exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE, exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA, exp.AutoRefreshProperty: exp.Properties.Location.POST_SCHEMA, + exp.BackupProperty: exp.Properties.Location.POST_SCHEMA, exp.BlockCompressionProperty: exp.Properties.Location.POST_NAME, exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA, exp.ChecksumProperty: exp.Properties.Location.POST_NAME, @@ -380,8 +404,10 @@ class Generator(metaclass=_Generator): exp.FallbackProperty: exp.Properties.Location.POST_NAME, exp.FileFormatProperty: exp.Properties.Location.POST_WITH, exp.FreespaceProperty: exp.Properties.Location.POST_NAME, + exp.GlobalProperty: exp.Properties.Location.POST_CREATE, exp.HeapProperty: exp.Properties.Location.POST_WITH, exp.InheritsProperty: exp.Properties.Location.POST_SCHEMA, + exp.IcebergProperty: exp.Properties.Location.POST_CREATE, exp.InputModelProperty: exp.Properties.Location.POST_SCHEMA, exp.IsolatedLoadingProperty: exp.Properties.Location.POST_NAME, exp.JournalProperty: exp.Properties.Location.POST_NAME, @@ -414,6 +440,8 @@ class Generator(metaclass=_Generator): exp.SettingsProperty: exp.Properties.Location.POST_SCHEMA, exp.SetProperty: exp.Properties.Location.POST_CREATE, exp.SetConfigProperty: exp.Properties.Location.POST_SCHEMA, + exp.SharingProperty: exp.Properties.Location.POST_EXPRESSION, + exp.SequenceProperties: exp.Properties.Location.POST_EXPRESSION, exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA, exp.SqlReadWriteProperty: exp.Properties.Location.POST_SCHEMA, exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE, @@ -423,6 +451,8 @@ class Generator(metaclass=_Generator): exp.TransientProperty: exp.Properties.Location.POST_CREATE, exp.TransformModelProperty: exp.Properties.Location.POST_SCHEMA, exp.MergeTreeTTL: exp.Properties.Location.POST_SCHEMA, + exp.UnloggedProperty: exp.Properties.Location.POST_CREATE, + exp.ViewAttributeProperty: exp.Properties.Location.POST_SCHEMA, exp.VolatileProperty: exp.Properties.Location.POST_CREATE, exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION, exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME, @@ -441,6 +471,7 @@ class Generator(metaclass=_Generator): exp.Insert, exp.Join, exp.Select, + exp.Union, exp.Update, exp.Where, exp.With, @@ -626,7 +657,7 @@ class Generator(metaclass=_Generator): if isinstance(expression, self.WITH_SEPARATED_COMMENTS): return ( f"{self.sep()}{comments_sql}{sql}" - if sql[0].isspace() + if not sql or sql[0].isspace() else f"{comments_sql}{self.sep()}{sql}" ) @@ -869,7 +900,9 @@ class Generator(metaclass=_Generator): this = f" {this}" if this else "" index_type = expression.args.get("index_type") index_type = f" USING {index_type}" if index_type else "" - return f"UNIQUE{this}{index_type}" + on_conflict = self.sql(expression, "on_conflict") + on_conflict = f" {on_conflict}" if 
on_conflict else "" + return f"UNIQUE{this}{index_type}{on_conflict}" def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str: return self.sql(expression, "this") @@ -961,6 +994,31 @@ class Generator(metaclass=_Generator): expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{postexpression_props_sql}{index_sql}{no_schema_binding}{clone}" return self.prepend_ctes(expression, expression_sql) + def sequenceproperties_sql(self, expression: exp.SequenceProperties) -> str: + start = self.sql(expression, "start") + start = f"START WITH {start}" if start else "" + increment = self.sql(expression, "increment") + increment = f" INCREMENT BY {increment}" if increment else "" + minvalue = self.sql(expression, "minvalue") + minvalue = f" MINVALUE {minvalue}" if minvalue else "" + maxvalue = self.sql(expression, "maxvalue") + maxvalue = f" MAXVALUE {maxvalue}" if maxvalue else "" + owned = self.sql(expression, "owned") + owned = f" OWNED BY {owned}" if owned else "" + + cache = expression.args.get("cache") + if cache is None: + cache_str = "" + elif cache is True: + cache_str = " CACHE" + else: + cache_str = f" CACHE {cache}" + + options = self.expressions(expression, key="options", flat=True, sep=" ") + options = f" {options}" if options else "" + + return f"{start}{increment}{minvalue}{maxvalue}{cache_str}{options}{owned}".lstrip() + def clone_sql(self, expression: exp.Clone) -> str: this = self.sql(expression, "this") shallow = "SHALLOW " if expression.args.get("shallow") else "" @@ -968,8 +1026,9 @@ class Generator(metaclass=_Generator): return f"{shallow}{keyword} {this}" def describe_sql(self, expression: exp.Describe) -> str: - extended = " EXTENDED" if expression.args.get("extended") else "" - return f"DESCRIBE{extended} {self.sql(expression, 'this')}" + style = expression.args.get("style") + style = f" {style}" if style else "" + return f"DESCRIBE{style} {self.sql(expression, 'this')}" def heredoc_sql(self, expression: exp.Heredoc) -> str: tag = self.sql(expression, "tag") @@ -993,7 +1052,14 @@ class Generator(metaclass=_Generator): def cte_sql(self, expression: exp.CTE) -> str: alias = self.sql(expression, "alias") - return f"{alias} AS {self.wrap(expression)}" + + materialized = expression.args.get("materialized") + if materialized is False: + materialized = "NOT MATERIALIZED " + elif materialized: + materialized = "MATERIALIZED " + + return f"{alias} AS {materialized or ''}{self.wrap(expression)}" def tablealias_sql(self, expression: exp.TableAlias) -> str: alias = self.sql(expression, "this") @@ -1044,7 +1110,7 @@ class Generator(metaclass=_Generator): return f"{self.dialect.QUOTE_START}{this}{self.dialect.QUOTE_END}" def rawstring_sql(self, expression: exp.RawString) -> str: - string = self.escape_str(expression.this.replace("\\", "\\\\")) + string = self.escape_str(expression.this.replace("\\", "\\\\"), escape_backslash=False) return f"{self.dialect.QUOTE_START}{string}{self.dialect.QUOTE_END}" def datatypeparam_sql(self, expression: exp.DataTypeParam) -> str: @@ -1114,6 +1180,8 @@ class Generator(metaclass=_Generator): def drop_sql(self, expression: exp.Drop) -> str: this = self.sql(expression, "this") + expressions = self.expressions(expression, flat=True) + expressions = f" ({expressions})" if expressions else "" kind = expression.args["kind"] exists_sql = " IF EXISTS " if expression.args.get("exists") else " " temporary = " TEMPORARY" if expression.args.get("temporary") else "" @@ -1121,15 +1189,10 @@ class 
Generator(metaclass=_Generator): cascade = " CASCADE" if expression.args.get("cascade") else "" constraints = " CONSTRAINTS" if expression.args.get("constraints") else "" purge = " PURGE" if expression.args.get("purge") else "" - return ( - f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}{constraints}{purge}" - ) + return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{expressions}{cascade}{constraints}{purge}" def except_sql(self, expression: exp.Except) -> str: - return self.prepend_ctes( - expression, - self.set_operation(expression, self.except_op(expression)), - ) + return self.set_operations(expression) def except_op(self, expression: exp.Except) -> str: return f"EXCEPT{'' if expression.args.get('distinct') else ' ALL'}" @@ -1163,17 +1226,9 @@ class Generator(metaclass=_Generator): return f" /*+ {self.expressions(expression, sep=self.QUERY_HINT_SEP).strip()} */" - def index_sql(self, expression: exp.Index) -> str: - unique = "UNIQUE " if expression.args.get("unique") else "" - primary = "PRIMARY " if expression.args.get("primary") else "" - amp = "AMP " if expression.args.get("amp") else "" - name = self.sql(expression, "this") - name = f"{name} " if name else "" - table = self.sql(expression, "table") - table = f"{self.INDEX_ON} {table}" if table else "" + def indexparameters_sql(self, expression: exp.IndexParameters) -> str: using = self.sql(expression, "using") using = f" USING {using}" if using else "" - index = "INDEX " if not table else "" columns = self.expressions(expression, key="columns", flat=True) columns = f"({columns})" if columns else "" partition_by = self.expressions(expression, key="partition_by", flat=True) @@ -1182,7 +1237,26 @@ class Generator(metaclass=_Generator): include = self.expressions(expression, key="include", flat=True) if include: include = f" INCLUDE ({include})" - return f"{unique}{primary}{amp}{index}{name}{table}{using}{columns}{include}{partition_by}{where}" + with_storage = self.expressions(expression, key="with_storage", flat=True) + with_storage = f" WITH ({with_storage})" if with_storage else "" + tablespace = self.sql(expression, "tablespace") + tablespace = f" USING INDEX TABLESPACE {tablespace}" if tablespace else "" + + return f"{using}{columns}{include}{with_storage}{tablespace}{partition_by}{where}" + + def index_sql(self, expression: exp.Index) -> str: + unique = "UNIQUE " if expression.args.get("unique") else "" + primary = "PRIMARY " if expression.args.get("primary") else "" + amp = "AMP " if expression.args.get("amp") else "" + name = self.sql(expression, "this") + name = f"{name} " if name else "" + table = self.sql(expression, "table") + table = f"{self.INDEX_ON} {table}" if table else "" + + index = "INDEX " if not table else "" + + params = self.sql(expression, "params") + return f"{unique}{primary}{amp}{index}{name}{table}{params}" def identifier_sql(self, expression: exp.Identifier) -> str: text = expression.name @@ -1371,15 +1445,9 @@ class Generator(metaclass=_Generator): no = " NO" if no else "" concurrent = expression.args.get("concurrent") concurrent = " CONCURRENT" if concurrent else "" - - for_ = "" - if expression.args.get("for_all"): - for_ = " FOR ALL" - elif expression.args.get("for_insert"): - for_ = " FOR INSERT" - elif expression.args.get("for_none"): - for_ = " FOR NONE" - return f"WITH{no}{concurrent} ISOLATED LOADING{for_}" + target = self.sql(expression, "target") + target = f" {target}" if target else "" + return f"WITH{no}{concurrent} ISOLATED LOADING{target}" def 
partitionboundspec_sql(self, expression: exp.PartitionBoundSpec) -> str: if isinstance(expression.this, list): @@ -1437,6 +1505,7 @@ class Generator(metaclass=_Generator): return f"{sql})" def insert_sql(self, expression: exp.Insert) -> str: + hint = self.sql(expression, "hint") overwrite = expression.args.get("overwrite") if isinstance(expression.this, exp.Directory): @@ -1447,7 +1516,9 @@ class Generator(metaclass=_Generator): alternative = expression.args.get("alternative") alternative = f" OR {alternative}" if alternative else "" ignore = " IGNORE" if expression.args.get("ignore") else "" - + is_function = expression.args.get("is_function") + if is_function: + this = f"{this} FUNCTION" this = f"{this} {self.sql(expression, 'this')}" exists = " IF EXISTS" if expression.args.get("exists") else "" @@ -1457,23 +1528,21 @@ class Generator(metaclass=_Generator): where = self.sql(expression, "where") where = f"{self.sep()}REPLACE WHERE {where}" if where else "" expression_sql = f"{self.sep()}{self.sql(expression, 'expression')}" - conflict = self.sql(expression, "conflict") + on_conflict = self.sql(expression, "conflict") + on_conflict = f" {on_conflict}" if on_conflict else "" by_name = " BY NAME" if expression.args.get("by_name") else "" returning = self.sql(expression, "returning") if self.RETURNING_END: - expression_sql = f"{expression_sql}{conflict}{returning}" + expression_sql = f"{expression_sql}{on_conflict}{returning}" else: - expression_sql = f"{returning}{expression_sql}{conflict}" + expression_sql = f"{returning}{expression_sql}{on_conflict}" - sql = f"INSERT{alternative}{ignore}{this}{by_name}{exists}{partition_sql}{where}{expression_sql}" + sql = f"INSERT{hint}{alternative}{ignore}{this}{by_name}{exists}{partition_sql}{where}{expression_sql}" return self.prepend_ctes(expression, sql) def intersect_sql(self, expression: exp.Intersect) -> str: - return self.prepend_ctes( - expression, - self.set_operation(expression, self.intersect_op(expression)), - ) + return self.set_operations(expression) def intersect_op(self, expression: exp.Intersect) -> str: return f"INTERSECT{'' if expression.args.get('distinct') else ' ALL'}" @@ -1496,33 +1565,36 @@ class Generator(metaclass=_Generator): def onconflict_sql(self, expression: exp.OnConflict) -> str: conflict = "ON DUPLICATE KEY" if expression.args.get("duplicate") else "ON CONFLICT" + constraint = self.sql(expression, "constraint") - if constraint: - constraint = f"ON CONSTRAINT {constraint}" - key = self.expressions(expression, key="key", flat=True) - do = "" if expression.args.get("duplicate") else " DO " - nothing = "NOTHING" if expression.args.get("nothing") else "" + constraint = f" ON CONSTRAINT {constraint}" if constraint else "" + + conflict_keys = self.expressions(expression, key="conflict_keys", flat=True) + conflict_keys = f"({conflict_keys}) " if conflict_keys else " " + action = self.sql(expression, "action") + expressions = self.expressions(expression, flat=True) - set_keyword = "SET " if self.DUPLICATE_KEY_UPDATE_WITH_SET else "" if expressions: - expressions = f"UPDATE {set_keyword}{expressions}" - return f"{self.seg(conflict)} {constraint}{key}{do}{nothing}{expressions}" + set_keyword = "SET " if self.DUPLICATE_KEY_UPDATE_WITH_SET else "" + expressions = f" {set_keyword}{expressions}" + + return f"{conflict}{constraint}{conflict_keys}{action}{expressions}" def returning_sql(self, expression: exp.Returning) -> str: return f"{self.seg('RETURNING')} {self.expressions(expression, flat=True)}" def 
rowformatdelimitedproperty_sql(self, expression: exp.RowFormatDelimitedProperty) -> str: - fields = expression.args.get("fields") + fields = self.sql(expression, "fields") fields = f" FIELDS TERMINATED BY {fields}" if fields else "" - escaped = expression.args.get("escaped") + escaped = self.sql(expression, "escaped") escaped = f" ESCAPED BY {escaped}" if escaped else "" - items = expression.args.get("collection_items") + items = self.sql(expression, "collection_items") items = f" COLLECTION ITEMS TERMINATED BY {items}" if items else "" - keys = expression.args.get("map_keys") + keys = self.sql(expression, "map_keys") keys = f" MAP KEYS TERMINATED BY {keys}" if keys else "" - lines = expression.args.get("lines") + lines = self.sql(expression, "lines") lines = f" LINES TERMINATED BY {lines}" if lines else "" - null = expression.args.get("null") + null = self.sql(expression, "null") null = f" NULL DEFINED AS {null}" if null else "" return f"ROW FORMAT DELIMITED{fields}{escaped}{items}{keys}{lines}{null}" @@ -1563,7 +1635,9 @@ class Generator(metaclass=_Generator): hints = f" {hints}" if hints and self.TABLE_HINTS else "" pivots = self.expressions(expression, key="pivots", sep=" ", flat=True) pivots = f" {pivots}" if pivots else "" - joins = self.expressions(expression, key="joins", sep="", skip_first=True) + joins = self.indent( + self.expressions(expression, key="joins", sep="", flat=True), skip_first=True + ) laterals = self.expressions(expression, key="laterals", sep="") file_format = self.sql(expression, "format") @@ -1673,9 +1747,11 @@ class Generator(metaclass=_Generator): sql = f"UPDATE {this} SET {set_sql}{expression_sql}{order}{limit}" return self.prepend_ctes(expression, sql) - def values_sql(self, expression: exp.Values) -> str: + def values_sql(self, expression: exp.Values, values_as_table: bool = True) -> str: + values_as_table = values_as_table and self.VALUES_AS_TABLE + # The VALUES clause is still valid in an `INSERT INTO ..` statement, for example - if self.VALUES_AS_TABLE or not expression.find_ancestor(exp.From, exp.Join): + if values_as_table or not expression.find_ancestor(exp.From, exp.Join): args = self.expressions(expression) alias = self.sql(expression, "alias") values = f"VALUES{self.seg('')}{args}" @@ -1769,8 +1845,9 @@ class Generator(metaclass=_Generator): def connect_sql(self, expression: exp.Connect) -> str: start = self.sql(expression, "start") start = self.seg(f"START WITH {start}") if start else "" + nocycle = " NOCYCLE" if expression.args.get("nocycle") else "" connect = self.sql(expression, "connect") - connect = self.seg(f"CONNECT BY {connect}") + connect = self.seg(f"CONNECT BY{nocycle} {connect}") return start + connect def prior_sql(self, expression: exp.Prior) -> str: @@ -1793,6 +1870,8 @@ class Generator(metaclass=_Generator): ) if op ) + match_cond = self.sql(expression, "match_condition") + match_cond = f" MATCH_CONDITION ({match_cond})" if match_cond else "" on_sql = self.sql(expression, "on") using = expression.args.get("using") @@ -1816,7 +1895,7 @@ class Generator(metaclass=_Generator): return f", {this_sql}" op_sql = f"{op_sql} JOIN" if op_sql else "JOIN" - return f"{self.seg(op_sql)} {this_sql}{on_sql}" + return f"{self.seg(op_sql)} {this_sql}{match_cond}{on_sql}" def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str: args = self.expressions(expression, flat=True) @@ -1919,13 +1998,17 @@ class Generator(metaclass=_Generator): text = f"{self.dialect.QUOTE_START}{self.escape_str(text)}{self.dialect.QUOTE_END}" return text - 
def escape_str(self, text: str) -> str: - text = text.replace(self.dialect.QUOTE_END, self._escaped_quote_end) - if self.dialect.INVERSE_ESCAPE_SEQUENCES: - text = "".join(self.dialect.INVERSE_ESCAPE_SEQUENCES.get(ch, ch) for ch in text) - elif self.pretty: + def escape_str(self, text: str, escape_backslash: bool = True) -> str: + if self.dialect.ESCAPED_SEQUENCES: + to_escaped = self.dialect.ESCAPED_SEQUENCES + text = "".join( + to_escaped.get(ch, ch) if escape_backslash or ch != "\\" else ch for ch in text + ) + + if self.pretty: text = text.replace("\n", self.SENTINEL_LINE_BREAK) - return text + + return text.replace(self.dialect.QUOTE_END, self._escaped_quote_end) def loaddata_sql(self, expression: exp.LoadData) -> str: local = " LOCAL" if expression.args.get("local") else "" @@ -2016,7 +2099,7 @@ class Generator(metaclass=_Generator): self.unsupported( f"'{nulls_sort_change.strip()}' translation not supported with positional ordering" ) - else: + elif not isinstance(expression.this, exp.Rand): null_sort_order = " DESC" if nulls_sort_change == " NULLS FIRST" else "" this = f"CASE WHEN {this} IS NULL THEN 1 ELSE 0 END{null_sort_order}, {this}" nulls_sort_change = "" @@ -2059,24 +2142,13 @@ class Generator(metaclass=_Generator): return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}{alias}" def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str: - limit: t.Optional[exp.Fetch | exp.Limit] = expression.args.get("limit") - - # If the limit is generated as TOP, we need to ensure it's not generated twice - with_offset_limit_modifiers = not isinstance(limit, exp.Limit) or not self.LIMIT_IS_TOP + limit = expression.args.get("limit") if self.LIMIT_FETCH == "LIMIT" and isinstance(limit, exp.Fetch): limit = exp.Limit(expression=exp.maybe_copy(limit.args.get("count"))) elif self.LIMIT_FETCH == "FETCH" and isinstance(limit, exp.Limit): limit = exp.Fetch(direction="FIRST", count=exp.maybe_copy(limit.expression)) - fetch = isinstance(limit, exp.Fetch) - - offset_limit_modifiers = ( - self.offset_limit_modifiers(expression, fetch, limit) - if with_offset_limit_modifiers - else [] - ) - options = self.expressions(expression, key="options") if options: options = f" OPTION{self.wrap(options)}" @@ -2091,9 +2163,9 @@ class Generator(metaclass=_Generator): self.sql(expression, "where"), self.sql(expression, "group"), self.sql(expression, "having"), - *self.after_having_modifiers(expression), + *[gen(self, expression) for gen in self.AFTER_HAVING_MODIFIER_TRANSFORMS.values()], self.sql(expression, "order"), - *offset_limit_modifiers, + *self.offset_limit_modifiers(expression, isinstance(limit, exp.Fetch), limit), *self.after_limit_modifiers(expression), options, sep="", @@ -2110,19 +2182,6 @@ class Generator(metaclass=_Generator): self.sql(limit) if fetch else self.sql(expression, "offset"), ] - def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]: - return [ - self.sql(expression, "qualify"), - ( - self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True) - if expression.args.get("windows") - else "" - ), - self.sql(expression, "distribute"), - self.sql(expression, "sort"), - self.sql(expression, "cluster"), - ] - def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]: locks = self.expressions(expression, key="locks", sep=" ") locks = f" {locks}" if locks else "" @@ -2137,12 +2196,13 @@ class Generator(metaclass=_Generator): distinct = self.sql(expression, "distinct") distinct = f" {distinct}" if distinct else "" kind = 
self.sql(expression, "kind") + limit = expression.args.get("limit") - top = ( - self.limit_sql(limit, top=True) - if isinstance(limit, exp.Limit) and self.LIMIT_IS_TOP - else "" - ) + if isinstance(limit, exp.Limit) and self.LIMIT_IS_TOP: + top = self.limit_sql(limit, top=True) + limit.pop() + else: + top = "" expressions = self.expressions(expression) @@ -2220,7 +2280,7 @@ class Generator(metaclass=_Generator): return f"@@{kind}{this}" def placeholder_sql(self, expression: exp.Placeholder) -> str: - return f"{self.NAMED_PLACEHOLDER_TOKEN}{expression.name}" if expression.name else "?" + return f"{self.NAMED_PLACEHOLDER_TOKEN}{expression.name}" if expression.this else "?" def subquery_sql(self, expression: exp.Subquery, sep: str = " AS ") -> str: alias = self.sql(expression, "alias") @@ -2236,11 +2296,32 @@ class Generator(metaclass=_Generator): this = self.indent(self.sql(expression, "this")) return f"{self.seg('QUALIFY')}{self.sep()}{this}" + def set_operations(self, expression: exp.Union) -> str: + sqls: t.List[str] = [] + stack: t.List[t.Union[str, exp.Expression]] = [expression] + + while stack: + node = stack.pop() + + if isinstance(node, exp.Union): + stack.append(node.expression) + stack.append( + self.maybe_comment( + getattr(self, f"{node.key}_op")(node), + expression=node.this, + comments=node.comments, + ) + ) + stack.append(node.this) + else: + sqls.append(self.sql(node)) + + this = self.sep().join(sqls) + this = self.query_modifiers(expression, this) + return self.prepend_ctes(expression, this) + def union_sql(self, expression: exp.Union) -> str: - return self.prepend_ctes( - expression, - self.set_operation(expression, self.union_op(expression)), - ) + return self.set_operations(expression) def union_op(self, expression: exp.Union) -> str: kind = " DISTINCT" if self.EXPLICIT_UNION else "" @@ -2345,8 +2426,10 @@ class Generator(metaclass=_Generator): def any_sql(self, expression: exp.Any) -> str: this = self.sql(expression, "this") - if isinstance(expression.this, exp.UNWRAPPED_QUERIES): - this = self.wrap(this) + if isinstance(expression.this, (*exp.UNWRAPPED_QUERIES, exp.Paren)): + if isinstance(expression.this, exp.UNWRAPPED_QUERIES): + this = self.wrap(this) + return f"ANY{this}" return f"ANY {this}" def exists_sql(self, expression: exp.Exists) -> str: @@ -2632,13 +2715,8 @@ class Generator(metaclass=_Generator): return self.func(self.sql(expression, "this"), *expression.expressions) def paren_sql(self, expression: exp.Paren) -> str: - if isinstance(expression.unnest(), exp.Select): - sql = self.wrap(expression) - else: - sql = self.seg(self.indent(self.sql(expression, "this")), sep="") - sql = f"({sql}{self.seg(')', sep='')}" - - return self.prepend_ctes(expression, sql) + sql = self.seg(self.indent(self.sql(expression, "this")), sep="") + return f"({sql}{self.seg(')', sep='')}" def neg_sql(self, expression: exp.Neg) -> str: # This makes sure we don't convert "- - 5" to "--5", which is a comment @@ -2686,23 +2764,55 @@ class Generator(metaclass=_Generator): def add_sql(self, expression: exp.Add) -> str: return self.binary(expression, "+") - def and_sql(self, expression: exp.And) -> str: - return self.connector_sql(expression, "AND") + def and_sql( + self, expression: exp.And, stack: t.Optional[t.List[str | exp.Expression]] = None + ) -> str: + return self.connector_sql(expression, "AND", stack) - def xor_sql(self, expression: exp.Xor) -> str: - return self.connector_sql(expression, "XOR") + def or_sql( + self, expression: exp.Or, stack: t.Optional[t.List[str | 
exp.Expression]] = None + ) -> str: + return self.connector_sql(expression, "OR", stack) - def connector_sql(self, expression: exp.Connector, op: str) -> str: - if not self.pretty: - return self.binary(expression, op) + def xor_sql( + self, expression: exp.Xor, stack: t.Optional[t.List[str | exp.Expression]] = None + ) -> str: + return self.connector_sql(expression, "XOR", stack) - sqls = tuple( - self.maybe_comment(self.sql(e), e, e.parent.comments or []) if i != 1 else self.sql(e) - for i, e in enumerate(expression.flatten(unnest=False)) - ) + def connector_sql( + self, + expression: exp.Connector, + op: str, + stack: t.Optional[t.List[str | exp.Expression]] = None, + ) -> str: + if stack is not None: + if expression.expressions: + stack.append(self.expressions(expression, sep=f" {op} ")) + else: + stack.append(expression.right) + if expression.comments: + for comment in expression.comments: + op += f" /*{self.pad_comment(comment)}*/" + stack.extend((op, expression.left)) + return op + + stack = [expression] + sqls: t.List[str] = [] + ops = set() + + while stack: + node = stack.pop() + if isinstance(node, exp.Connector): + ops.add(getattr(self, f"{node.key}_sql")(node, stack)) + else: + sql = self.sql(node) + if sqls and sqls[-1] in ops: + sqls[-1] += f" {sql}" + else: + sqls.append(sql) - sep = "\n" if self.text_width(sqls) > self.max_text_width else " " - return f"{sep}{op} ".join(sqls) + sep = "\n" if self.pretty and self.text_width(sqls) > self.max_text_width else " " + return sep.join(sqls) def bitwiseand_sql(self, expression: exp.BitwiseAnd) -> str: return self.binary(expression, "&") @@ -2727,7 +2837,9 @@ class Generator(metaclass=_Generator): format_sql = f" FORMAT {format_sql}" if format_sql else "" to_sql = self.sql(expression, "to") to_sql = f" {to_sql}" if to_sql else "" - return f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS{to_sql}{format_sql})" + action = self.sql(expression, "action") + action = f" {action}" if action else "" + return f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS{to_sql}{format_sql}{action})" def currentdate_sql(self, expression: exp.CurrentDate) -> str: zone = self.sql(expression, "this") @@ -2817,7 +2929,7 @@ class Generator(metaclass=_Generator): # Remove db from tables expression = expression.transform( lambda n: exp.table_(n.this) if isinstance(n, exp.Table) else n - ) + ).assert_is(exp.RenameTable) this = self.sql(expression, "this") return f"RENAME TO {this}" @@ -2889,30 +3001,6 @@ class Generator(metaclass=_Generator): kind = "MAX" if expression.args.get("max") else "MIN" return f"{this_sql} HAVING {kind} {expression_sql}" - def _embed_ignore_nulls(self, expression: exp.IgnoreNulls | exp.RespectNulls, text: str) -> str: - if self.IGNORE_NULLS_IN_FUNC and not expression.meta.get("inline"): - # The first modifier here will be the one closest to the AggFunc's arg - mods = sorted( - expression.find_all(exp.HavingMax, exp.Order, exp.Limit), - key=lambda x: 0 - if isinstance(x, exp.HavingMax) - else (1 if isinstance(x, exp.Order) else 2), - ) - - if mods: - mod = mods[0] - this = expression.__class__(this=mod.this.copy()) - this.meta["inline"] = True - mod.this.replace(this) - return self.sql(expression.this) - - agg_func = expression.find(exp.AggFunc) - - if agg_func: - return self.sql(agg_func)[:-1] + f" {text})" - - return f"{self.sql(expression, 'this')} {text}" - def intdiv_sql(self, expression: exp.IntDiv) -> str: return self.sql( exp.Cast( @@ -2933,9 +3021,7 @@ class Generator(metaclass=_Generator): 
r.replace(exp.Nullif(this=r.copy(), expression=exp.Literal.number(0))) if self.dialect.TYPED_DIVISION and not expression.args.get("typed"): - if not l.is_type(*exp.DataType.FLOAT_TYPES) and not r.is_type( - *exp.DataType.FLOAT_TYPES - ): + if not l.is_type(*exp.DataType.REAL_TYPES) and not r.is_type(*exp.DataType.REAL_TYPES): l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DOUBLE)) elif not self.dialect.TYPED_DIVISION and expression.args.get("typed"): @@ -3019,9 +3105,6 @@ class Generator(metaclass=_Generator): def nullsafeneq_sql(self, expression: exp.NullSafeNEQ) -> str: return self.binary(expression, "IS DISTINCT FROM") - def or_sql(self, expression: exp.Or) -> str: - return self.connector_sql(expression, "OR") - def slice_sql(self, expression: exp.Slice) -> str: return self.binary(expression, ":") @@ -3035,8 +3118,13 @@ class Generator(metaclass=_Generator): this = expression.this expr = expression.expression - if not self.dialect.LOG_BASE_FIRST: + if self.dialect.LOG_BASE_FIRST is False: this, expr = expr, this + elif self.dialect.LOG_BASE_FIRST is None and expr: + if this.name in ("2", "10"): + return self.func(f"LOG{this.name}", expr) + + self.unsupported(f"Unsupported logarithm with base {self.sql(this)}") return self.func("LOG", this, expr) @@ -3088,11 +3176,16 @@ class Generator(metaclass=_Generator): def text_width(self, args: t.Iterable) -> int: return sum(len(arg) for arg in args) - def format_time(self, expression: exp.Expression) -> t.Optional[str]: + def format_time( + self, + expression: exp.Expression, + inverse_time_mapping: t.Optional[t.Dict[str, str]] = None, + inverse_time_trie: t.Optional[t.Dict] = None, + ) -> t.Optional[str]: return format_time( self.sql(expression, "format"), - self.dialect.INVERSE_TIME_MAPPING, - self.dialect.INVERSE_TIME_TRIE, + inverse_time_mapping or self.dialect.INVERSE_TIME_MAPPING, + inverse_time_trie or self.dialect.INVERSE_TIME_TRIE, ) def expressions( @@ -3117,8 +3210,11 @@ class Generator(metaclass=_Generator): num_sqls = len(expressions) # These are calculated once in case we have the leading_comma / pretty option set, correspondingly - pad = " " * self.pad - stripped_sep = sep.strip() + if self.pretty: + if self.leading_comma: + pad = " " * len(sep) + else: + stripped_sep = sep.strip() result_sqls = [] for i, e in enumerate(expressions): @@ -3154,13 +3250,6 @@ class Generator(metaclass=_Generator): self.unsupported(f"Unsupported property {expression.__class__.__name__}") return f"{property_name} {self.sql(expression, 'this')}" - def set_operation(self, expression: exp.Union, op: str) -> str: - this = self.maybe_comment(self.sql(expression, "this"), comments=expression.comments) - op = self.seg(op) - return self.query_modifiers( - expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}" - ) - def tag_sql(self, expression: exp.Tag) -> str: return f"{expression.args.get('prefix')}{self.sql(expression.this)}{expression.args.get('postfix')}" @@ -3227,6 +3316,18 @@ class Generator(metaclass=_Generator): return self.sql(exp.cast(expression.this, "text")) + def tonumber_sql(self, expression: exp.ToNumber) -> str: + if not self.SUPPORTS_TO_NUMBER: + self.unsupported("Unsupported TO_NUMBER function") + return self.sql(exp.cast(expression.this, "double")) + + fmt = expression.args.get("format") + if not fmt: + self.unsupported("Conversion format is required for TO_NUMBER") + return self.sql(exp.cast(expression.this, "double")) + + return self.func("TO_NUMBER", expression.this, fmt) + def dictproperty_sql(self, expression: 
exp.DictProperty) -> str: this = self.sql(expression, "this") kind = self.sql(expression, "kind") @@ -3320,11 +3421,11 @@ class Generator(metaclass=_Generator): this = f" {this}" if this else "" index_type = self.sql(expression, "index_type") index_type = f" USING {index_type}" if index_type else "" - schema = self.sql(expression, "schema") - schema = f" {schema}" if schema else "" + expressions = self.expressions(expression, flat=True) + expressions = f" ({expressions})" if expressions else "" options = self.expressions(expression, key="options", sep=" ") options = f" {options}" if options else "" - return f"{kind}{this}{index_type}{schema}{options}" + return f"{kind}{this}{index_type}{expressions}{options}" def nvl2_sql(self, expression: exp.Nvl2) -> str: if self.NVL2_SUPPORTED: @@ -3396,6 +3497,13 @@ class Generator(metaclass=_Generator): return self.sql(exp.cast(this, "time")) + def tsordstotimestamp_sql(self, expression: exp.TsOrDsToTimestamp) -> str: + this = expression.this + if isinstance(this, exp.TsOrDsToTimestamp) or this.is_type(exp.DataType.Type.TIMESTAMP): + return self.sql(this) + + return self.sql(exp.cast(this, "timestamp")) + def tsordstodate_sql(self, expression: exp.TsOrDsToDate) -> str: this = expression.this time_format = self.format_time(expression) @@ -3430,6 +3538,13 @@ class Generator(metaclass=_Generator): return self.func("LAST_DAY", expression.this) + def dateadd_sql(self, expression: exp.DateAdd) -> str: + from sqlglot.dialects.dialect import unit_to_str + + return self.func( + "DATE_ADD", expression.this, expression.expression, unit_to_str(expression) + ) + def arrayany_sql(self, expression: exp.ArrayAny) -> str: if self.CAN_IMPLEMENT_ARRAY_ANY: filtered = exp.ArrayFilter(this=expression.this, expression=expression.expression) @@ -3445,30 +3560,6 @@ class Generator(metaclass=_Generator): return self.function_fallback_sql(expression) - def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str: - this = expression.this - if isinstance(this, exp.JSONPathWildcard): - this = self.json_path_part(this) - return f".{this}" if this else "" - - if exp.SAFE_IDENTIFIER_RE.match(this): - return f".{this}" - - this = self.json_path_part(this) - return f"[{this}]" if self.JSON_PATH_BRACKETED_KEY_SUPPORTED else f".{this}" - - def _jsonpathsubscript_sql(self, expression: exp.JSONPathSubscript) -> str: - this = self.json_path_part(expression.this) - return f"[{this}]" if this else "" - - def _simplify_unless_literal(self, expression: E) -> E: - if not isinstance(expression, exp.Literal): - from sqlglot.optimizer.simplify import simplify - - expression = simplify(expression, dialect=self.dialect) - - return expression - def generateseries_sql(self, expression: exp.GenerateSeries) -> str: expression.set("is_end_exclusive", None) return self.function_fallback_sql(expression) @@ -3477,7 +3568,9 @@ class Generator(metaclass=_Generator): expression.set( "expressions", [ - exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e + exp.alias_(e.expression, e.name if e.this.is_string else e.this) + if isinstance(e, exp.PropertyEQ) + else e for e in expression.expressions ], ) @@ -3553,3 +3646,51 @@ class Generator(metaclass=_Generator): transformed = cast(this=value, to=to, safe=safe) return self.sql(transformed) + + def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str: + this = expression.this + if isinstance(this, exp.JSONPathWildcard): + this = self.json_path_part(this) + return f".{this}" if this else "" + + if exp.SAFE_IDENTIFIER_RE.match(this): + 
return f".{this}" + + this = self.json_path_part(this) + return f"[{this}]" if self.JSON_PATH_BRACKETED_KEY_SUPPORTED else f".{this}" + + def _jsonpathsubscript_sql(self, expression: exp.JSONPathSubscript) -> str: + this = self.json_path_part(expression.this) + return f"[{this}]" if this else "" + + def _simplify_unless_literal(self, expression: E) -> E: + if not isinstance(expression, exp.Literal): + from sqlglot.optimizer.simplify import simplify + + expression = simplify(expression, dialect=self.dialect) + + return expression + + def _embed_ignore_nulls(self, expression: exp.IgnoreNulls | exp.RespectNulls, text: str) -> str: + if self.IGNORE_NULLS_IN_FUNC and not expression.meta.get("inline"): + # The first modifier here will be the one closest to the AggFunc's arg + mods = sorted( + expression.find_all(exp.HavingMax, exp.Order, exp.Limit), + key=lambda x: 0 + if isinstance(x, exp.HavingMax) + else (1 if isinstance(x, exp.Order) else 2), + ) + + if mods: + mod = mods[0] + this = expression.__class__(this=mod.this.copy()) + this.meta["inline"] = True + mod.this.replace(this) + return self.sql(expression.this) + + agg_func = expression.find(exp.AggFunc) + + if agg_func: + return self.sql(agg_func)[:-1] + f" {text})" + + return f"{self.sql(expression, 'this')} {text}" diff --git a/sqlglot/helper.py b/sqlglot/helper.py index 0d4547f..0187c51 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -181,7 +181,7 @@ def apply_index_offset( annotate_types(expression) if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES: logger.warning("Applying array index offset (%s)", offset) - expression = simplify(exp.Add(this=expression, expression=exp.Literal.number(offset))) + expression = simplify(expression + offset) return [expression] return expressions @@ -204,13 +204,13 @@ def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> The transformed expression. """ while True: - for n, *_ in reversed(tuple(expression.walk())): + for n in reversed(tuple(expression.walk())): n._hash = hash(n) start = hash(expression) expression = func(expression) - for n, *_ in expression.walk(): + for n in expression.walk(): n._hash = None if start == hash(expression): break @@ -317,8 +317,16 @@ def find_new_name(taken: t.Collection[str], base: str) -> str: def is_int(text: str) -> bool: + return is_type(text, int) + + +def is_float(text: str) -> bool: + return is_type(text, float) + + +def is_type(text: str, target_type: t.Type) -> bool: try: - int(text) + target_type(text) return True except ValueError: return False diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py index eb428dc..c91bb36 100644 --- a/sqlglot/lineage.py +++ b/sqlglot/lineage.py @@ -28,10 +28,7 @@ class Node: yield self for d in self.downstream: - if isinstance(d, Node): - yield from d.walk() - else: - yield d + yield from d.walk() def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML: nodes = {} @@ -71,8 +68,10 @@ def lineage( column: str | exp.Column, sql: str | exp.Expression, schema: t.Optional[t.Dict | Schema] = None, - sources: t.Optional[t.Dict[str, str | exp.Query]] = None, + sources: t.Optional[t.Mapping[str, str | exp.Query]] = None, dialect: DialectType = None, + scope: t.Optional[Scope] = None, + trim_selects: bool = True, **kwargs, ) -> Node: """Build the lineage graph for a column of a SQL query. @@ -83,6 +82,8 @@ def lineage( schema: The schema of tables. sources: A mapping of queries which will be used to continue building lineage. dialect: The dialect of input SQL. 
+ scope: A pre-created scope to use instead. + trim_selects: Whether or not to clean up selects by trimming to only relevant columns. **kwargs: Qualification optimizer kwargs. Returns: @@ -99,14 +100,15 @@ def lineage( dialect=dialect, ) - qualified = qualify.qualify( - expression, - dialect=dialect, - schema=schema, - **{"validate_qualify_columns": False, "identify": False, **kwargs}, # type: ignore - ) + if not scope: + expression = qualify.qualify( + expression, + dialect=dialect, + schema=schema, + **{"validate_qualify_columns": False, "identify": False, **kwargs}, # type: ignore + ) - scope = build_scope(qualified) + scope = build_scope(expression) if not scope: raise SqlglotError("Cannot build lineage, sql must be SELECT") @@ -114,7 +116,7 @@ def lineage( if not any(select.alias_or_name == column for select in scope.expression.selects): raise SqlglotError(f"Cannot find column '{column}' in query.") - return to_node(column, scope, dialect) + return to_node(column, scope, dialect, trim_selects=trim_selects) def to_node( @@ -125,6 +127,7 @@ def to_node( upstream: t.Optional[Node] = None, source_name: t.Optional[str] = None, reference_node_name: t.Optional[str] = None, + trim_selects: bool = True, ) -> Node: source_names = { dt.alias: dt.comments[0].split()[1] @@ -143,6 +146,17 @@ def to_node( ) ) + if isinstance(scope.expression, exp.Subquery): + for source in scope.subquery_scopes: + return to_node( + column, + scope=source, + dialect=dialect, + upstream=upstream, + source_name=source_name, + reference_node_name=reference_node_name, + trim_selects=trim_selects, + ) if isinstance(scope.expression, exp.Union): upstream = upstream or Node(name="UNION", source=scope.expression, expression=select) @@ -170,11 +184,12 @@ def to_node( upstream=upstream, source_name=source_name, reference_node_name=reference_node_name, + trim_selects=trim_selects, ) return upstream - if isinstance(scope.expression, exp.Select): + if trim_selects and isinstance(scope.expression, exp.Select): # For better ergonomics in our node labels, replace the full select with # a version that has only the column we care about. # "x", SELECT x, y FROM foo @@ -206,7 +221,13 @@ def to_node( continue for name in subquery.named_selects: - to_node(name, scope=subquery_scope, dialect=dialect, upstream=node) + to_node( + name, + scope=subquery_scope, + dialect=dialect, + upstream=node, + trim_selects=trim_selects, + ) # if the select is a star add all scope sources as downstreams if select.is_star: @@ -237,6 +258,7 @@ def to_node( upstream=node, source_name=source_names.get(table) or source_name, reference_node_name=selected_node.name if selected_node else None, + trim_selects=trim_selects, ) else: # The source is not a scope - we've reached the end of the line. 
At this point, if a source is not found diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 81b1ee6..c85ef1c 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -168,8 +168,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.Exp, exp.Ln, exp.Log, - exp.Log2, - exp.Log10, exp.Pow, exp.Quantile, exp.Round, @@ -266,26 +264,30 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.Dot: lambda self, e: self._annotate_dot(e), exp.Explode: lambda self, e: self._annotate_explode(e), exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), + exp.GenerateDateArray: lambda self, e: self._annotate_with_type( + e, exp.DataType.build("ARRAY<DATE>") + ), exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"), exp.Literal: lambda self, e: self._annotate_literal(e), - exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP), + exp.Map: lambda self, e: self._annotate_map(e), exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"), exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), - exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True), + exp.Struct: lambda self, e: self._annotate_struct(e), exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), exp.Timestamp: lambda self, e: self._annotate_with_type( e, exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP, ), + exp.ToMap: lambda self, e: self._annotate_to_map(e), exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), exp.Unnest: lambda self, e: self._annotate_unnest(e), - exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP), + exp.VarMap: lambda self, e: self._annotate_map(e), } NESTED_TYPES = { @@ -358,6 +360,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator): if isinstance(source.expression, exp.Lateral): if isinstance(source.expression.this, exp.Explode): values = [source.expression.this.this] + elif isinstance(source.expression, exp.Unnest): + values = [source.expression] else: values = source.expression.expressions[0].expressions @@ -408,7 +412,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator): ) def _annotate_args(self, expression: E) -> E: - for _, value in expression.iter_expressions(): + for value in expression.iter_expressions(): self._maybe_annotate(value) return expression @@ -425,23 +429,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator): if exp.DataType.Type.UNKNOWN in (type1_value, type2_value): return exp.DataType.Type.UNKNOWN - if type1_value in self.NESTED_TYPES: - return type1 - if type2_value in self.NESTED_TYPES: - return type2 - - return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value # type: ignore - - # Note: the following "no_type_check" decorators were added because mypy was yelling due - # to assigning Type values to expression.type (since its getter returns Optional[DataType]). 
- # This is a known mypy issue: https://github.com/python/mypy/issues/3004 + return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value - @t.no_type_check def _annotate_binary(self, expression: B) -> B: self._annotate_args(expression) left, right = expression.left, expression.right - left_type, right_type = left.type.this, right.type.this + left_type, right_type = left.type.this, right.type.this # type: ignore if isinstance(expression, exp.Connector): if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL: @@ -462,7 +456,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator): return expression - @t.no_type_check def _annotate_unary(self, expression: E) -> E: self._annotate_args(expression) @@ -473,7 +466,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator): return expression - @t.no_type_check def _annotate_literal(self, expression: exp.Literal) -> exp.Literal: if expression.is_string: self._set_type(expression, exp.DataType.Type.VARCHAR) @@ -484,33 +476,17 @@ class TypeAnnotator(metaclass=_TypeAnnotator): return expression - @t.no_type_check def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E: self._set_type(expression, target_type) return self._annotate_args(expression) @t.no_type_check - def _annotate_struct_value( - self, expression: exp.Expression - ) -> t.Optional[exp.DataType] | exp.ColumnDef: - alias = expression.args.get("alias") - if alias: - return exp.ColumnDef(this=alias.copy(), kind=expression.type) - - # Case: key = value or key := value - if expression.expression: - return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type) - - return expression.type - - @t.no_type_check def _annotate_by_args( self, expression: E, *args: str, promote: bool = False, array: bool = False, - struct: bool = False, ) -> E: self._annotate_args(expression) @@ -546,16 +522,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator): ), ) - if struct: - self._set_type( - expression, - exp.DataType( - this=exp.DataType.Type.STRUCT, - expressions=[self._annotate_struct_value(expr) for expr in expressions], - nested=True, - ), - ) - return expression def _annotate_timeunit( @@ -605,6 +571,10 @@ class TypeAnnotator(metaclass=_TypeAnnotator): self._set_type(expression, exp.DataType.Type.BIGINT) else: self._set_type(expression, self._maybe_coerce(left_type, right_type)) + if expression.type and expression.type.this not in exp.DataType.REAL_TYPES: + self._set_type( + expression, self._maybe_coerce(expression.type, exp.DataType.Type.DOUBLE) + ) return expression @@ -631,3 +601,68 @@ class TypeAnnotator(metaclass=_TypeAnnotator): child = seq_get(expression.expressions, 0) self._set_type(expression, child and seq_get(child.type.expressions, 0)) return expression + + def _annotate_struct_value( + self, expression: exp.Expression + ) -> t.Optional[exp.DataType] | exp.ColumnDef: + alias = expression.args.get("alias") + if alias: + return exp.ColumnDef(this=alias.copy(), kind=expression.type) + + # Case: key = value or key := value + if expression.expression: + return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type) + + return expression.type + + def _annotate_struct(self, expression: exp.Struct) -> exp.Struct: + self._annotate_args(expression) + self._set_type( + expression, + exp.DataType( + this=exp.DataType.Type.STRUCT, + expressions=[self._annotate_struct_value(expr) for expr in expression.expressions], + nested=True, + ), + ) + return expression + + @t.overload + def _annotate_map(self, 
expression: exp.Map) -> exp.Map: ... + + @t.overload + def _annotate_map(self, expression: exp.VarMap) -> exp.VarMap: ... + + def _annotate_map(self, expression): + self._annotate_args(expression) + + keys = expression.args.get("keys") + values = expression.args.get("values") + + map_type = exp.DataType(this=exp.DataType.Type.MAP) + if isinstance(keys, exp.Array) and isinstance(values, exp.Array): + key_type = seq_get(keys.type.expressions, 0) or exp.DataType.Type.UNKNOWN + value_type = seq_get(values.type.expressions, 0) or exp.DataType.Type.UNKNOWN + + if key_type != exp.DataType.Type.UNKNOWN and value_type != exp.DataType.Type.UNKNOWN: + map_type.set("expressions", [key_type, value_type]) + map_type.set("nested", True) + + self._set_type(expression, map_type) + return expression + + def _annotate_to_map(self, expression: exp.ToMap) -> exp.ToMap: + self._annotate_args(expression) + + map_type = exp.DataType(this=exp.DataType.Type.MAP) + arg = expression.this + if arg.is_type(exp.DataType.Type.STRUCT): + for coldef in arg.type.expressions: + kind = coldef.kind + if kind != exp.DataType.Type.UNKNOWN: + map_type.set("expressions", [exp.DataType.build("varchar"), kind]) + map_type.set("nested", True) + break + + self._set_type(expression, map_type) + return expression diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py index 0aa8134..17a5089 100644 --- a/sqlglot/optimizer/canonicalize.py +++ b/sqlglot/optimizer/canonicalize.py @@ -16,16 +16,17 @@ def canonicalize(expression: exp.Expression) -> exp.Expression: Args: expression: The expression to canonicalize. """ - exp.replace_children(expression, canonicalize) - expression = add_text_to_concat(expression) - expression = replace_date_funcs(expression) - expression = coerce_type(expression) - expression = remove_redundant_casts(expression) - expression = ensure_bools(expression, _replace_int_predicate) - expression = remove_ascending_order(expression) + def _canonicalize(expression: exp.Expression) -> exp.Expression: + expression = add_text_to_concat(expression) + expression = replace_date_funcs(expression) + expression = coerce_type(expression) + expression = remove_redundant_casts(expression) + expression = ensure_bools(expression, _replace_int_predicate) + expression = remove_ascending_order(expression) + return expression - return expression + return exp.replace_tree(expression, _canonicalize) def add_text_to_concat(node: exp.Expression) -> exp.Expression: @@ -35,7 +36,11 @@ def add_text_to_concat(node: exp.Expression) -> exp.Expression: def replace_date_funcs(node: exp.Expression) -> exp.Expression: - if isinstance(node, exp.Date) and not node.expressions and not node.args.get("zone"): + if ( + isinstance(node, (exp.Date, exp.TsOrDsToDate)) + and not node.expressions + and not node.args.get("zone") + ): return exp.cast(node.this, to=exp.DataType.Type.DATE) if isinstance(node, exp.Timestamp) and not node.expression: if not node.type: @@ -121,15 +126,11 @@ def _coerce_date(a: exp.Expression, b: exp.Expression) -> None: a = _coerce_timeunit_arg(a, b.unit) if ( a.type - and a.type.this == exp.DataType.Type.DATE + and a.type.this in exp.DataType.TEMPORAL_TYPES and b.type - and b.type.this - not in ( - exp.DataType.Type.DATE, - exp.DataType.Type.INTERVAL, - ) + and b.type.this in exp.DataType.TEXT_TYPES ): - _replace_cast(b, exp.DataType.Type.DATE) + _replace_cast(b, exp.DataType.Type.DATETIME) def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression: @@ -169,7 +170,7 @@ 
def _replace_cast(node: exp.Expression, to: exp.DataType.Type) -> None: # with y as (select true as x) select x = 0 FROM y -- illegal presto query def _replace_int_predicate(expression: exp.Expression) -> None: if isinstance(expression, exp.Coalesce): - for _, child in expression.iter_expressions(): + for child in expression.iter_expressions(): _replace_int_predicate(child) elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES: expression.replace(expression.neq(0)) diff --git a/sqlglot/optimizer/eliminate_ctes.py b/sqlglot/optimizer/eliminate_ctes.py index 6f1865c..d2e876c 100644 --- a/sqlglot/optimizer/eliminate_ctes.py +++ b/sqlglot/optimizer/eliminate_ctes.py @@ -32,7 +32,7 @@ def eliminate_ctes(expression): cte_node.pop() # Pop the entire WITH clause if this is the last CTE - if len(with_node.expressions) <= 0: + if with_node and len(with_node.expressions) <= 0: with_node.pop() # Decrement the ref count for all sources this CTE selects from diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index ea148cc..603f5df 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -214,6 +214,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): and not _outer_select_joins_on_inner_select_join() and not _is_a_window_expression_in_unmergable_operation() and not _is_recursive() + and not (inner_select.args.get("order") and outer_scope.is_union) ) diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py index 6bf877b..49b6c98 100644 --- a/sqlglot/optimizer/normalize.py +++ b/sqlglot/optimizer/normalize.py @@ -28,7 +28,7 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = Returns: sqlglot.Expression: normalized expression """ - for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))): + for node in tuple(expression.walk(prune=lambda e: isinstance(e, exp.Connector))): if isinstance(node, exp.Connector): if normalized(node, dnf=dnf): continue diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py index f2a0990..eb84c00 100644 --- a/sqlglot/optimizer/normalize_identifiers.py +++ b/sqlglot/optimizer/normalize_identifiers.py @@ -53,10 +53,8 @@ def normalize_identifiers(expression, dialect=None): if isinstance(expression, str): expression = exp.parse_identifier(expression, dialect=dialect) - def _normalize(node: E) -> E: + for node in expression.walk(prune=lambda n: n.meta.get("case_sensitive")): if not node.meta.get("case_sensitive"): - exp.replace_children(node, _normalize) - node = dialect.normalize_identifier(node) - return node + dialect.normalize_identifier(node) - return _normalize(expression) + return expression diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index 1c96e95..c82b8aa 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -82,13 +82,13 @@ def optimize( **kwargs, } - expression = exp.maybe_parse(expression, dialect=dialect, copy=True) + optimized = exp.maybe_parse(expression, dialect=dialect, copy=True) for rule in rules: # Find any additional rule parameters, beyond `expression` rule_params = rule.__code__.co_varnames rule_kwargs = { param: possible_kwargs[param] for param in rule_params if param in possible_kwargs } - expression = rule(expression, **rule_kwargs) + optimized = rule(optimized, **rule_kwargs) - return t.cast(exp.Expression, expression) + return optimized 
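The `optimizer.py` change above renames the loop variable but keeps the rule-dispatch contract: each rule is inspected and receives only the keyword arguments its signature declares. A minimal, hedged sketch of driving the optimizer (the schema and query are illustrative, not taken from the diff):

from sqlglot import parse_one
from sqlglot.optimizer import optimize

# Rules whose signatures declare `schema` (e.g. qualify) receive it;
# other rules are called with just the expression and any kwargs they declare.
optimized = optimize(
    parse_one("SELECT a FROM x WHERE b = 1"),
    schema={"x": {"a": "INT", "b": "INT"}},
)
print(optimized.sql(pretty=True))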
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index 12c3b89..18c9e83 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -77,13 +77,13 @@ def pushdown(condition, sources, scope_ref_count, dialect, join_index=None): pushdown_dnf(predicates, sources, scope_ref_count) -def pushdown_cnf(predicates, scope, scope_ref_count, join_index=None): +def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None): """ If the predicates are in CNF like form, we can simply replace each block in the parent. """ join_index = join_index or {} for predicate in predicates: - for node in nodes_for_predicate(predicate, scope, scope_ref_count).values(): + for node in nodes_for_predicate(predicate, sources, scope_ref_count).values(): if isinstance(node, exp.Join): name = node.alias_or_name predicate_tables = exp.column_table_names(predicate, name) @@ -103,7 +103,7 @@ def pushdown_cnf(predicates, scope, scope_ref_count, join_index=None): node.where(inner_predicate, copy=False) -def pushdown_dnf(predicates, scope, scope_ref_count): +def pushdown_dnf(predicates, sources, scope_ref_count): """ If the predicates are in DNF form, we can only push down conditions that are in all blocks. Additionally, we can't remove predicates from their original form. @@ -127,7 +127,7 @@ def pushdown_dnf(predicates, scope, scope_ref_count): # pushdown all predicates to their respective nodes for table in sorted(pushdown_tables): for predicate in predicates: - nodes = nodes_for_predicate(predicate, scope, scope_ref_count) + nodes = nodes_for_predicate(predicate, sources, scope_ref_count) if table not in nodes: continue diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index 53490bf..d97fd36 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -54,11 +54,15 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True) if any(select.is_star for select in right.expression.selects): referenced_columns[right] = parent_selections elif not any(select.is_star for select in left.expression.selects): - referenced_columns[right] = [ - right.expression.selects[i].alias_or_name - for i, select in enumerate(left.expression.selects) - if SELECT_ALL in parent_selections or select.alias_or_name in parent_selections - ] + if scope.expression.args.get("by_name"): + referenced_columns[right] = referenced_columns[left] + else: + referenced_columns[right] = [ + right.expression.selects[i].alias_or_name + for i, select in enumerate(left.expression.selects) + if SELECT_ALL in parent_selections + or select.alias_or_name in parent_selections + ] if isinstance(scope.expression, exp.Select): if remove_unused_selections: diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 233ffc9..027c32c 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -209,7 +209,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: if not node: return - for column, *_ in walk_in_scope(node, prune=lambda node, *_: node.is_star): + for column in walk_in_scope(node, prune=lambda node: node.is_star): if not isinstance(column, exp.Column): continue @@ -306,7 +306,7 @@ def _expand_positional_references( else: select = select.this - if isinstance(select, exp.Literal): + if isinstance(select, exp.CONSTANTS) or select.find(exp.Explode, exp.Unnest): 
new_nodes.append(node) else: new_nodes.append(select.copy()) @@ -425,7 +425,7 @@ def _expand_stars( raise OptimizeError(f"Unknown table: {table}") columns = resolver.get_source_columns(table, only_visible=True) - columns = columns or scope.outer_column_list + columns = columns or scope.outer_columns if pseudocolumns: columns = [name for name in columns if name.upper() not in pseudocolumns] @@ -517,7 +517,7 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None: new_selections = [] for i, (selection, aliased_column) in enumerate( - itertools.zip_longest(scope.expression.selects, scope.outer_column_list) + itertools.zip_longest(scope.expression.selects, scope.outer_columns) ): if selection is None: break @@ -544,7 +544,7 @@ def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool """Makes sure all identifiers that need to be quoted are quoted.""" return expression.transform( Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False - ) + ) # type: ignore def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression: diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index 214ac0a..a034bf5 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -56,7 +56,7 @@ def qualify_tables( table.set("catalog", catalog) if not isinstance(expression, exp.Query): - for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Query)): + for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)): if isinstance(node, exp.Table): _qualify(node) @@ -118,11 +118,11 @@ def qualify_tables( for i, e in enumerate(udtf.expressions[0].expressions): table_alias.append("columns", exp.to_identifier(f"_col_{i}")) else: - for node, parent, _ in scope.walk(): + for node in scope.walk(): if ( isinstance(node, exp.Table) and not node.alias - and isinstance(parent, (exp.From, exp.Join)) + and isinstance(node.parent, (exp.From, exp.Join)) ): # Mutates the table by attaching an alias to it alias(node, node.name, copy=False, table=True) diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 443fa6c..073ced2 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -8,7 +8,7 @@ from enum import Enum, auto from sqlglot import exp from sqlglot.errors import OptimizeError -from sqlglot.helper import ensure_collection, find_new_name +from sqlglot.helper import ensure_collection, find_new_name, seq_get logger = logging.getLogger("sqlglot") @@ -38,11 +38,11 @@ class Scope: SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; The LATERAL VIEW EXPLODE gets x as a source. cte_sources (dict[str, Scope]): Sources from CTES - outer_column_list (list[str]): If this is a derived table or CTE, and the outer query - defines a column list of it's alias of this scope, this is that list of columns. + outer_columns (list[str]): If this is a derived table or CTE, and the outer query + defines a column list for the alias of this scope, this is that list of columns. For example: SELECT * FROM (SELECT ...) 
AS y(col1, col2) - The inner query would have `["col1", "col2"]` for its `outer_column_list` + The inner query would have `["col1", "col2"]` for its `outer_columns` parent (Scope): Parent scope scope_type (ScopeType): Type of this scope, relative to its parent subquery_scopes (list[Scope]): List of all child scopes for subqueries @@ -58,7 +58,7 @@ class Scope: self, expression, sources=None, - outer_column_list=None, + outer_columns=None, parent=None, scope_type=ScopeType.ROOT, lateral_sources=None, @@ -70,7 +70,7 @@ class Scope: self.cte_sources = cte_sources or {} self.sources.update(self.lateral_sources) self.sources.update(self.cte_sources) - self.outer_column_list = outer_column_list or [] + self.outer_columns = outer_columns or [] self.parent = parent self.scope_type = scope_type self.subquery_scopes = [] @@ -119,10 +119,11 @@ class Scope: self._raw_columns = [] self._join_hints = [] - for node, parent, _ in self.walk(bfs=False): + for node in self.walk(bfs=False): if node is self.expression: continue - elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): + + if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): self._raw_columns.append(node) elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): self._tables.append(node) @@ -132,10 +133,8 @@ class Scope: self._udtfs.append(node) elif isinstance(node, exp.CTE): self._ctes.append(node) - elif ( - isinstance(node, exp.Subquery) - and isinstance(parent, (exp.From, exp.Join, exp.Subquery)) - and _is_derived_table(node) + elif _is_derived_table(node) and isinstance( + node.parent, (exp.From, exp.Join, exp.Subquery) ): self._derived_tables.append(node) elif isinstance(node, exp.UNWRAPPED_QUERIES): @@ -438,11 +437,21 @@ class Scope: Yields: Scope: scope instances in depth-first-search post-order """ - for child_scope in itertools.chain( - self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes - ): - yield from child_scope.traverse() - yield self + stack = [self] + result = [] + while stack: + scope = stack.pop() + result.append(scope) + stack.extend( + itertools.chain( + scope.cte_scopes, + scope.union_scopes, + scope.table_scopes, + scope.subquery_scopes, + ) + ) + + yield from reversed(result) def ref_count(self): """ @@ -481,14 +490,28 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]: ('SELECT a FROM (SELECT a FROM x) AS y', ['y']) Args: - expression (exp.Expression): expression to traverse + expression: Expression to traverse Returns: - list[Scope]: scope instances + A list of the created scope instances """ - if isinstance(expression, exp.Query) or ( - isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query) - ): + if isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query): + # We ignore the DDL expression and build a scope for its query instead + ddl_with = expression.args.get("with") + expression = expression.expression + + # If the DDL has CTEs attached, we need to add them to the query, or + # prepend them if the query itself already has CTEs attached to it + if ddl_with: + ddl_with.pop() + query_ctes = expression.ctes + if not query_ctes: + expression.set("with", ddl_with) + else: + expression.args["with"].set("recursive", ddl_with.recursive) + expression.args["with"].set("expressions", [*ddl_with.expressions, *query_ctes]) + + if isinstance(expression, exp.Query): return list(_traverse_scope(Scope(expression))) return [] @@ -499,21 +522,21 @@ def build_scope(expression:
exp.Expression) -> t.Optional[Scope]: Build a scope tree. Args: - expression (exp.Expression): expression to build the scope tree for + expression: Expression to build the scope tree for. + Returns: - Scope: root scope + The root scope """ - scopes = traverse_scope(expression) - if scopes: - return scopes[-1] - return None + return seq_get(traverse_scope(expression), -1) def _traverse_scope(scope): if isinstance(scope.expression, exp.Select): yield from _traverse_select(scope) elif isinstance(scope.expression, exp.Union): + yield from _traverse_ctes(scope) yield from _traverse_union(scope) + return elif isinstance(scope.expression, exp.Subquery): if scope.is_root: yield from _traverse_select(scope) @@ -523,8 +546,6 @@ def _traverse_scope(scope): yield from _traverse_tables(scope) elif isinstance(scope.expression, exp.UDTF): yield from _traverse_udtfs(scope) - elif isinstance(scope.expression, exp.DDL): - yield from _traverse_ddl(scope) else: logger.warning( "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression) @@ -541,30 +562,38 @@ def _traverse_select(scope): def _traverse_union(scope): - yield from _traverse_ctes(scope) + prev_scope = None + union_scope_stack = [scope] + expression_stack = [scope.expression.right, scope.expression.left] - # The last scope to be yield should be the top most scope - left = None - for left in _traverse_scope( - scope.branch( - scope.expression.left, - outer_column_list=scope.outer_column_list, - scope_type=ScopeType.UNION, - ) - ): - yield left + while expression_stack: + expression = expression_stack.pop() + union_scope = union_scope_stack[-1] - right = None - for right in _traverse_scope( - scope.branch( - scope.expression.right, - outer_column_list=scope.outer_column_list, + new_scope = union_scope.branch( + expression, + outer_columns=union_scope.outer_columns, scope_type=ScopeType.UNION, ) - ): - yield right - scope.union_scopes = [left, right] + if isinstance(expression, exp.Union): + yield from _traverse_ctes(new_scope) + + union_scope_stack.append(new_scope) + expression_stack.extend([expression.right, expression.left]) + continue + + for scope in _traverse_scope(new_scope): + yield scope + + if prev_scope: + union_scope_stack.pop() + union_scope.union_scopes = [prev_scope, scope] + prev_scope = union_scope + + yield union_scope + else: + prev_scope = scope def _traverse_ctes(scope): @@ -588,7 +617,7 @@ def _traverse_ctes(scope): scope.branch( cte.this, cte_sources=sources, - outer_column_list=cte.alias_column_names, + outer_columns=cte.alias_column_names, scope_type=ScopeType.CTE, ) ): @@ -615,7 +644,9 @@ def _is_derived_table(expression: exp.Subquery) -> bool: as it doesn't introduce a new scope. If an alias is present, it shadows all names under the Subquery, so that's one exception to this rule.
""" - return bool(expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES)) + return isinstance(expression, exp.Subquery) and bool( + expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES) + ) def _traverse_tables(scope): @@ -681,7 +712,7 @@ def _traverse_tables(scope): scope.branch( expression, lateral_sources=lateral_sources, - outer_column_list=expression.alias_column_names, + outer_columns=expression.alias_column_names, scope_type=scope_type, ) ): @@ -719,13 +750,13 @@ def _traverse_udtfs(scope): sources = {} for expression in expressions: - if isinstance(expression, exp.Subquery) and _is_derived_table(expression): + if _is_derived_table(expression): top = None for child_scope in _traverse_scope( scope.branch( expression, scope_type=ScopeType.DERIVED_TABLE, - outer_column_list=expression.alias_column_names, + outer_columns=expression.alias_column_names, ) ): yield child_scope @@ -738,18 +769,6 @@ def _traverse_udtfs(scope): scope.sources.update(sources) -def _traverse_ddl(scope): - yield from _traverse_ctes(scope) - - query_scope = scope.branch( - scope.expression.expression, scope_type=ScopeType.DERIVED_TABLE, sources=scope.sources - ) - query_scope._collect() - query_scope._ctes = scope.ctes + query_scope._ctes - - yield from _traverse_scope(query_scope) - - def walk_in_scope(expression, bfs=True, prune=None): """ Returns a generator object which visits all nodes in the syntax tree, stopping at @@ -769,23 +788,21 @@ def walk_in_scope(expression, bfs=True, prune=None): # Whenever we set it to True, we exclude a subtree from traversal. crossed_scope_boundary = False - for node, parent, key in expression.walk( - bfs=bfs, prune=lambda *args: crossed_scope_boundary or (prune and prune(*args)) + for node in expression.walk( + bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n)) ): crossed_scope_boundary = False - yield node, parent, key + yield node if node is expression: continue if ( isinstance(node, exp.CTE) or ( - isinstance(node, exp.Subquery) - and isinstance(parent, (exp.From, exp.Join, exp.Subquery)) - and _is_derived_table(node) + isinstance(node.parent, (exp.From, exp.Join, exp.Subquery)) + and (_is_derived_table(node) or isinstance(node, exp.UDTF)) ) - or isinstance(node, exp.UDTF) or isinstance(node, exp.UNWRAPPED_QUERIES) ): crossed_scope_boundary = True @@ -812,7 +829,7 @@ def find_all_in_scope(expression, expression_types, bfs=True): Yields: exp.Expression: nodes """ - for expression, *_ in walk_in_scope(expression, bfs=bfs): + for expression in walk_in_scope(expression, bfs=bfs): if isinstance(expression, tuple(ensure_collection(expression_types))): yield expression diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 2e43d21..d9a0d2b 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -9,19 +9,25 @@ from decimal import Decimal import sqlglot from sqlglot import Dialect, exp -from sqlglot.helper import first, is_iterable, merge_ranges, while_changing +from sqlglot.helper import first, merge_ranges, while_changing from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope if t.TYPE_CHECKING: from sqlglot.dialects.dialect import DialectType DateTruncBinaryTransform = t.Callable[ - [exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression] + [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression] ] # Final means that an expression should not be simplified FINAL = "final" +# Value ranges for byte-sized
signed/unsigned integers +TINYINT_MIN = -128 +TINYINT_MAX = 127 +UTINYINT_MIN = 0 +UTINYINT_MAX = 255 + class UnsupportedUnit(Exception): pass @@ -63,14 +69,14 @@ def simplify( group.meta[FINAL] = True for e in expression.selects: - for node, *_ in e.walk(): + for node in e.walk(): if node in groups: e.meta[FINAL] = True break having = expression.args.get("having") if having: - for node, *_ in having.walk(): + for node in having.walk(): if node in groups: having.meta[FINAL] = True break @@ -304,6 +310,8 @@ def _simplify_comparison(expression, left, right, or_=False): r = extract_date(r) if not r: return None + # python won't compare date and datetime, but many engines will upcast + l, r = cast_as_datetime(l), cast_as_datetime(r) for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): @@ -431,7 +439,7 @@ def propagate_constants(expression, root=True): and sqlglot.optimizer.normalize.normalized(expression, dnf=True) ): constant_mapping = {} - for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)): + for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)): if isinstance(expr, exp.EQ): l, r = expr.left, expr.right @@ -544,7 +552,37 @@ def simplify_literals(expression, root=True): return expression +NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ) + + +def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression: + if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast): + this = _simplify_integer_cast(expr.this) + else: + this = expr.this + + if isinstance(expr, exp.Cast) and this.is_int: + num = int(this.name) + + # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. 
Downcasts on any + # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is + # engine-dependent + if ( + TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES + ) or ( + UTINYINT_MIN <= num <= UTINYINT_MAX + and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES + ): + return this + + return expr + + def _simplify_binary(expression, a, b): + if isinstance(expression, COMPARISONS): + a = _simplify_integer_cast(a) + b = _simplify_integer_cast(b) + if isinstance(expression, exp.Is): if isinstance(b, exp.Not): c = b.this @@ -558,7 +596,7 @@ def _simplify_binary(expression, a, b): return exp.true() if not_ else exp.false() if is_null(a): return exp.false() if not_ else exp.true() - elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)): + elif isinstance(expression, NULL_OK): return None elif is_null(a) or is_null(b): return exp.null() @@ -591,17 +629,17 @@ def _simplify_binary(expression, a, b): if boolean: return boolean elif _is_date_literal(a) and isinstance(b, exp.Interval): - a, b = extract_date(a), extract_interval(b) - if a and b: + date, b = extract_date(a), extract_interval(b) + if date and b: if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)): - return date_literal(a + b) + return date_literal(date + b, extract_type(a)) if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)): - return date_literal(a - b) + return date_literal(date - b, extract_type(a)) elif isinstance(a, exp.Interval) and _is_date_literal(b): - a, b = extract_interval(a), extract_date(b) + a, date = extract_interval(a), extract_date(b) # you cannot subtract a date from an interval if a and b and isinstance(expression, exp.Add): - return date_literal(a + b) + return date_literal(a + date, extract_type(b)) elif _is_date_literal(a) and _is_date_literal(b): if isinstance(expression, exp.Predicate): a, b = extract_date(a), extract_date(b) @@ -618,12 +656,16 @@ def simplify_parens(expression): this = expression.this parent = expression.parent + parent_is_predicate = isinstance(parent, exp.Predicate) if not isinstance(this, exp.Select) and ( not isinstance(parent, (exp.Condition, exp.Binary)) or isinstance(parent, exp.Paren) - or not isinstance(this, exp.Binary) - or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate)) + or ( + not isinstance(this, exp.Binary) + and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate) + ) + or (isinstance(this, exp.Predicate) and not parent_is_predicate) or (isinstance(this, exp.Add) and isinstance(parent, exp.Add)) or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul)) or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub))) @@ -632,24 +674,12 @@ def simplify_parens(expression): return expression -NONNULL_CONSTANTS = ( - exp.Literal, - exp.Boolean, -) - -CONSTANTS = ( - exp.Literal, - exp.Boolean, - exp.Null, -) - - def _is_nonnull_constant(expression: exp.Expression) -> bool: - return isinstance(expression, NONNULL_CONSTANTS) or _is_date_literal(expression) + return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression) def _is_constant(expression: exp.Expression) -> bool: - return isinstance(expression, CONSTANTS) or _is_date_literal(expression) + return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression) def simplify_coalesce(expression): @@ -820,45 +850,55 @@ def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Opti return floor, floor + interval(unit) -def 
_datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression: +def _datetrunc_eq_expression( + left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType] +) -> exp.Expression: """Get the logical expression for a date range""" return exp.and_( - left >= date_literal(drange[0]), - left < date_literal(drange[1]), + left >= date_literal(drange[0], target_type), + left < date_literal(drange[1], target_type), copy=False, ) def _datetrunc_eq( - left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect + left: exp.Expression, + date: datetime.date, + unit: str, + dialect: Dialect, + target_type: t.Optional[exp.DataType], ) -> t.Optional[exp.Expression]: drange = _datetrunc_range(date, unit, dialect) if not drange: return None - return _datetrunc_eq_expression(left, drange) + return _datetrunc_eq_expression(left, drange, target_type) def _datetrunc_neq( - left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect + left: exp.Expression, + date: datetime.date, + unit: str, + dialect: Dialect, + target_type: t.Optional[exp.DataType], ) -> t.Optional[exp.Expression]: drange = _datetrunc_range(date, unit, dialect) if not drange: return None return exp.and_( - left < date_literal(drange[0]), - left >= date_literal(drange[1]), + left < date_literal(drange[0], target_type), + left >= date_literal(drange[1], target_type), copy=False, ) DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = { - exp.LT: lambda l, dt, u, d: l - < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u)), - exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)), - exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)), - exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)), + exp.LT: lambda l, dt, u, d, t: l + < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t), + exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t), + exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t), + exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t), exp.EQ: _datetrunc_eq, exp.NEQ: _datetrunc_neq, } @@ -876,9 +916,11 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr comparison = expression.__class__ if isinstance(expression, DATETRUNCS): - date = extract_date(expression.this) + this = expression.this + trunc_type = extract_type(this) + date = extract_date(this) if date and expression.unit: - return date_literal(date_floor(date, expression.unit.name.lower(), dialect)) + return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type) elif comparison not in DATETRUNC_COMPARISONS: return expression @@ -889,14 +931,21 @@ def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expr return expression l = t.cast(exp.DateTrunc, l) + trunc_arg = l.this unit = l.unit.name.lower() date = extract_date(r) if not date: return expression - return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression - elif isinstance(expression, exp.In): + return ( + DATETRUNC_BINARY_COMPARISONS[comparison]( + trunc_arg, date, unit, dialect, extract_type(trunc_arg, r) + ) + or expression + ) + + if isinstance(expression, exp.In): l = expression.this rs = expression.expressions @@ -917,8 +966,11 @@ def simplify_datetrunc(expression: exp.Expression, 
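Together, the integer-cast stripping and the type threading above make `simplify` smarter about typed literals. An illustrative sketch, not part of the diff; the comments show the expected shape of the results, though operand ordering may differ:

from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

# 1 fits in a byte, so its cast can be dropped and the comparison folds;
# 1000 overflows TINYINT, so that cast must stay and nothing folds.
print(simplify(parse_one("CAST(1 AS INT) = 1")).sql())         # TRUE
print(simplify(parse_one("CAST(1000 AS TINYINT) = 1")).sql())  # unchanged

# DATE_TRUNC comparisons become half-open ranges; the threaded target_type
# keeps the generated literals cast to DATE rather than the default DATETIME.
print(simplify(parse_one("DATE_TRUNC('YEAR', x) = CAST('2021-01-01' AS DATE)")).sql())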
dialect: Dialect) -> exp.Expr return expression ranges = merge_ranges(ranges) + target_type = extract_type(l, *rs) - return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False) + return exp.or_( + *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False + ) return expression @@ -954,7 +1006,7 @@ JOINS = { def remove_where_true(expression): for where in expression.find_all(exp.Where): if always_true(where.this): - where.parent.set("where", None) + where.pop() for join in expression.find_all(exp.Join): if ( always_true(join.args.get("on")) @@ -962,7 +1014,7 @@ def remove_where_true(expression): and not join.args.get("method") and (join.side, join.kind) in JOINS ): - join.set("on", None) + join.args["on"].pop() join.set("side", None) join.set("kind", "CROSS") @@ -1067,15 +1119,25 @@ def extract_interval(expression): return None -def date_literal(date): - return exp.cast( - exp.Literal.string(date), - ( +def extract_type(*expressions): + target_type = None + for expression in expressions: + target_type = expression.to if isinstance(expression, exp.Cast) else expression.type + if target_type: + break + + return target_type + + +def date_literal(date, target_type=None): + if not target_type or not target_type.is_type(*exp.DataType.TEMPORAL_TYPES): + target_type = ( exp.DataType.Type.DATETIME if isinstance(date, datetime.datetime) else exp.DataType.Type.DATE - ), - ) + ) + + return exp.cast(exp.Literal.string(date), target_type) def interval(unit: str, n: int = 1): @@ -1169,73 +1231,251 @@ def gen(expression: t.Any) -> str: Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here. """ - if expression is None: - return "_" - if is_iterable(expression): - return ",".join(gen(e) for e in expression) - if not isinstance(expression, exp.Expression): - return str(expression) - - etype = type(expression) - if etype in GEN_MAP: - return GEN_MAP[etype](expression) - return f"{expression.key} {gen(expression.args.values())}" - - -GEN_MAP = { - exp.Add: lambda e: _binary(e, "+"), - exp.And: lambda e: _binary(e, "AND"), - exp.Anonymous: lambda e: _anonymous(e), - exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}", - exp.Boolean: lambda e: "TRUE" if e.this else "FALSE", - exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]", - exp.Column: lambda e: ".".join(gen(p) for p in e.parts), - exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}", - exp.Div: lambda e: _binary(e, "/"), - exp.Dot: lambda e: _binary(e, "."), - exp.EQ: lambda e: _binary(e, "="), - exp.GT: lambda e: _binary(e, ">"), - exp.GTE: lambda e: _binary(e, ">="), - exp.Identifier: lambda e: f'"{e.name}"' if e.quoted else e.name, - exp.ILike: lambda e: _binary(e, "ILIKE"), - exp.In: lambda e: f"{gen(e.this)} IN ({gen(tuple(e.args.values())[1:])})", - exp.Is: lambda e: _binary(e, "IS"), - exp.Like: lambda e: _binary(e, "LIKE"), - exp.Literal: lambda e: f"'{e.name}'" if e.is_string else e.name, - exp.LT: lambda e: _binary(e, "<"), - exp.LTE: lambda e: _binary(e, "<="), - exp.Mod: lambda e: _binary(e, "%"), - exp.Mul: lambda e: _binary(e, "*"), - exp.Neg: lambda e: _unary(e, "-"), - exp.NEQ: lambda e: _binary(e, "<>"), - exp.Not: lambda e: _unary(e, "NOT"), - exp.Null: lambda e: "NULL", - exp.Or: lambda e: _binary(e, "OR"), - exp.Paren: lambda e: f"({gen(e.this)})", - exp.Sub: lambda e: _binary(e, "-"), - exp.Subquery: lambda 
e: f"({gen(e.args.values())})", - exp.Table: lambda e: gen(e.args.values()), - exp.Var: lambda e: e.name, -} + return Gen().gen(expression) + + +class Gen: + def __init__(self): + self.stack = [] + self.sqls = [] + + def gen(self, expression: exp.Expression) -> str: + self.stack = [expression] + self.sqls.clear() + + while self.stack: + node = self.stack.pop() + + if isinstance(node, exp.Expression): + exp_handler_name = f"{node.key}_sql" + + if hasattr(self, exp_handler_name): + getattr(self, exp_handler_name)(node) + elif isinstance(node, exp.Func): + self._function(node) + else: + key = node.key.upper() + self.stack.append(f"{key} " if self._args(node) else key) + elif type(node) is list: + for n in reversed(node): + if n is not None: + self.stack.extend((n, ",")) + if node: + self.stack.pop() + else: + if node is not None: + self.sqls.append(str(node)) + return "".join(self.sqls) -def _anonymous(e: exp.Anonymous) -> str: - this = e.this - if isinstance(this, str): - name = this.upper() - elif isinstance(this, exp.Identifier): - name = f'"{this.name}"' if this.quoted else this.name.upper() - else: - raise ValueError( - f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'." + def add_sql(self, e: exp.Add) -> None: + self._binary(e, " + ") + + def alias_sql(self, e: exp.Alias) -> None: + self.stack.extend( + ( + e.args.get("alias"), + " AS ", + e.args.get("this"), + ) + ) + + def and_sql(self, e: exp.And) -> None: + self._binary(e, " AND ") + + def anonymous_sql(self, e: exp.Anonymous) -> None: + this = e.this + if isinstance(this, str): + name = this.upper() + elif isinstance(this, exp.Identifier): + name = this.this + name = f'"{name}"' if this.quoted else name.upper() + else: + raise ValueError( + f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'." 
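The `Gen` class above replaces the recursive `GEN_MAP` lambdas with an explicit stack, which avoids Python's recursion limit on deep expression trees. A quick illustration of what `gen` is for (not part of the diff): it produces a cheap normalization key used for sorting and deduping expressions, not necessarily valid SQL:

from sqlglot import parse_one
from sqlglot.optimizer.simplify import gen

x, y = parse_one("b AND a"), parse_one("b AND a")
print(gen(x))            # a stable key such as "b AND a"
print(gen(x) == gen(y))  # True: equal keys without running the full generator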
+ ) + + self.stack.extend( + ( + ")", + e.expressions, + "(", + name, + ) + ) + + def between_sql(self, e: exp.Between) -> None: + self.stack.extend( + ( + e.args.get("high"), + " AND ", + e.args.get("low"), + " BETWEEN ", + e.this, + ) + ) + + def boolean_sql(self, e: exp.Boolean) -> None: + self.stack.append("TRUE" if e.this else "FALSE") + + def bracket_sql(self, e: exp.Bracket) -> None: + self.stack.extend( + ( + "]", + e.expressions, + "[", + e.this, + ) + ) + + def column_sql(self, e: exp.Column) -> None: + for p in reversed(e.parts): + self.stack.extend((p, ".")) + self.stack.pop() + + def datatype_sql(self, e: exp.DataType) -> None: + self._args(e, 1) + self.stack.append(f"{e.this.name} ") + + def div_sql(self, e: exp.Div) -> None: + self._binary(e, " / ") + + def dot_sql(self, e: exp.Dot) -> None: + self._binary(e, ".") + + def eq_sql(self, e: exp.EQ) -> None: + self._binary(e, " = ") + + def from_sql(self, e: exp.From) -> None: + self.stack.extend((e.this, "FROM ")) + + def gt_sql(self, e: exp.GT) -> None: + self._binary(e, " > ") + + def gte_sql(self, e: exp.GTE) -> None: + self._binary(e, " >= ") + + def identifier_sql(self, e: exp.Identifier) -> None: + self.stack.append(f'"{e.this}"' if e.quoted else e.this) + + def ilike_sql(self, e: exp.ILike) -> None: + self._binary(e, " ILIKE ") + + def in_sql(self, e: exp.In) -> None: + self.stack.append(")") + self._args(e, 1) + self.stack.extend( + ( + "(", + " IN ", + e.this, + ) ) - return f"{name} {','.join(gen(e) for e in e.expressions)}" + def intdiv_sql(self, e: exp.IntDiv) -> None: + self._binary(e, " DIV ") + + def is_sql(self, e: exp.Is) -> None: + self._binary(e, " IS ") + + def like_sql(self, e: exp.Like) -> None: + self._binary(e, " Like ") + + def literal_sql(self, e: exp.Literal) -> None: + self.stack.append(f"'{e.this}'" if e.is_string else e.this) + + def lt_sql(self, e: exp.LT) -> None: + self._binary(e, " < ") + + def lte_sql(self, e: exp.LTE) -> None: + self._binary(e, " <= ") + + def mod_sql(self, e: exp.Mod) -> None: + self._binary(e, " % ") + + def mul_sql(self, e: exp.Mul) -> None: + self._binary(e, " * ") + def neg_sql(self, e: exp.Neg) -> None: + self._unary(e, "-") + + def neq_sql(self, e: exp.NEQ) -> None: + self._binary(e, " <> ") + + def not_sql(self, e: exp.Not) -> None: + self._unary(e, "NOT ") + + def null_sql(self, e: exp.Null) -> None: + self.stack.append("NULL") + + def or_sql(self, e: exp.Or) -> None: + self._binary(e, " OR ") + + def paren_sql(self, e: exp.Paren) -> None: + self.stack.extend( + ( + ")", + e.this, + "(", + ) + ) + + def sub_sql(self, e: exp.Sub) -> None: + self._binary(e, " - ") + + def subquery_sql(self, e: exp.Subquery) -> None: + self._args(e, 2) + alias = e.args.get("alias") + if alias: + self.stack.append(alias) + self.stack.extend((")", e.this, "(")) + + def table_sql(self, e: exp.Table) -> None: + self._args(e, 4) + alias = e.args.get("alias") + if alias: + self.stack.append(alias) + for p in reversed(e.parts): + self.stack.extend((p, ".")) + self.stack.pop() + + def tablealias_sql(self, e: exp.TableAlias) -> None: + columns = e.columns + + if columns: + self.stack.extend((")", columns, "(")) + + self.stack.extend((e.this, " AS ")) + + def var_sql(self, e: exp.Var) -> None: + self.stack.append(e.this) + + def _binary(self, e: exp.Binary, op: str) -> None: + self.stack.extend((e.expression, op, e.this)) + + def _unary(self, e: exp.Unary, op: str) -> None: + self.stack.extend((e.this, op)) + + def _function(self, e: exp.Func) -> None: + self.stack.extend( + ( + ")", + 
list(e.args.values()), + "(", + e.sql_name(), + ) + ) -def _binary(e: exp.Binary, op: str) -> str: - return f"{gen(e.left)} {op} {gen(e.right)}" + def _args(self, node: exp.Expression, arg_index: int = 0) -> bool: + kvs = [] + arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types + for k in arg_types or arg_types: + v = node.args.get(k) -def _unary(e: exp.Unary, op: str) -> str: - return f"{op} {gen(e.this)}" + if v is not None: + kvs.append([f":{k}", v]) + if kvs: + self.stack.append(kvs) + return True + return False diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index 36d9da4..b83abe6 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -138,7 +138,7 @@ def decorrelate(select, parent_select, external_columns, next_alias_name): if isinstance(predicate, exp.Binary): key = ( predicate.right - if any(node is column for node, *_ in predicate.left.walk()) + if any(node is column for node in predicate.left.walk()) else predicate.left ) else: diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 49dac2e..91d8d13 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -15,6 +15,8 @@ if t.TYPE_CHECKING: from sqlglot._typing import E, Lit from sqlglot.dialects.dialect import Dialect, DialectType + T = t.TypeVar("T") + logger = logging.getLogger("sqlglot") OPTIONS_TYPE = t.Dict[str, t.Sequence[t.Union[t.Sequence[str], str]]] @@ -119,6 +121,9 @@ class Parser(metaclass=_Parser): "JSON_EXTRACT_PATH_TEXT": build_extract_json_with_path(exp.JSONExtractScalar), "LIKE": build_like, "LOG": build_logarithm, + "LOG2": lambda args: exp.Log(this=exp.Literal.number(2), expression=seq_get(args, 0)), + "LOG10": lambda args: exp.Log(this=exp.Literal.number(10), expression=seq_get(args, 0)), + "MOD": lambda args: exp.Mod(this=seq_get(args, 0), expression=seq_get(args, 1)), "TIME_TO_TIME_STR": lambda args: exp.Cast( this=seq_get(args, 0), to=exp.DataType(this=exp.DataType.Type.TEXT), @@ -144,6 +149,7 @@ class Parser(metaclass=_Parser): STRUCT_TYPE_TOKENS = { TokenType.NESTED, + TokenType.OBJECT, TokenType.STRUCT, } @@ -258,6 +264,7 @@ class Parser(metaclass=_Parser): TokenType.IPV6, TokenType.UNKNOWN, TokenType.NULL, + TokenType.NAME, *ENUM_TYPE_TOKENS, *NESTED_TYPE_TOKENS, *AGGREGATE_TYPE_TOKENS, @@ -291,6 +298,7 @@ class Parser(metaclass=_Parser): TokenType.VIEW, TokenType.MODEL, TokenType.DICTIONARY, + TokenType.SEQUENCE, TokenType.STORAGE_INTEGRATION, } @@ -310,6 +318,7 @@ class Parser(metaclass=_Parser): TokenType.ANTI, TokenType.APPLY, TokenType.ASC, + TokenType.ASOF, TokenType.AUTO_INCREMENT, TokenType.BEGIN, TokenType.BPCHAR, @@ -398,6 +407,8 @@ class Parser(metaclass=_Parser): TokenType.WINDOW, } + ALIAS_TOKENS = ID_VAR_TOKENS + COMMENT_TABLE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.IS} UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET} @@ -433,6 +444,7 @@ class Parser(metaclass=_Parser): TokenType.VAR, TokenType.LEFT, TokenType.RIGHT, + TokenType.SEQUENCE, TokenType.DATE, TokenType.DATETIME, TokenType.TABLE, @@ -505,8 +517,9 @@ class Parser(metaclass=_Parser): } JOIN_METHODS = { - TokenType.NATURAL, TokenType.ASOF, + TokenType.NATURAL, + TokenType.POSITIONAL, } JOIN_SIDES = { @@ -611,8 +624,8 @@ class Parser(metaclass=_Parser): TokenType.ALTER: lambda self: self._parse_alter(), TokenType.BEGIN: lambda self: self._parse_transaction(), TokenType.CACHE: lambda self: self._parse_cache(), - TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(), TokenType.COMMENT: lambda 
self: self._parse_comment(), + TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(), TokenType.CREATE: lambda self: self._parse_create(), TokenType.DELETE: lambda self: self._parse_delete(), TokenType.DESC: lambda self: self._parse_describe(), @@ -627,9 +640,9 @@ class Parser(metaclass=_Parser): TokenType.REFRESH: lambda self: self._parse_refresh(), TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(), TokenType.SET: lambda self: self._parse_set(), + TokenType.TRUNCATE: lambda self: self._parse_truncate_table(), TokenType.UNCACHE: lambda self: self._parse_uncache(), TokenType.UPDATE: lambda self: self._parse_update(), - TokenType.TRUNCATE: lambda self: self._parse_truncate_table(), TokenType.USE: lambda self: self.expression( exp.Use, kind=self._parse_var_from_options(self.USABLES, raise_unmatched=False), @@ -714,6 +727,9 @@ class Parser(metaclass=_Parser): "ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty), "AUTO": lambda self: self._parse_auto_property(), "AUTO_INCREMENT": lambda self: self._parse_property_assignment(exp.AutoIncrementProperty), + "BACKUP": lambda self: self.expression( + exp.BackupProperty, this=self._parse_var(any_token=True) + ), "BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(), "CHARSET": lambda self, **kwargs: self._parse_character_set(**kwargs), "CHARACTER SET": lambda self, **kwargs: self._parse_character_set(**kwargs), @@ -739,7 +755,9 @@ class Parser(metaclass=_Parser): "FALLBACK": lambda self, **kwargs: self._parse_fallback(**kwargs), "FORMAT": lambda self: self._parse_property_assignment(exp.FileFormatProperty), "FREESPACE": lambda self: self._parse_freespace(), + "GLOBAL": lambda self: self.expression(exp.GlobalProperty), "HEAP": lambda self: self.expression(exp.HeapProperty), + "ICEBERG": lambda self: self.expression(exp.IcebergProperty), "IMMUTABLE": lambda self: self.expression( exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE") ), @@ -782,6 +800,7 @@ class Parser(metaclass=_Parser): "SETTINGS": lambda self: self.expression( exp.SettingsProperty, expressions=self._parse_csv(self._parse_set_item) ), + "SHARING": lambda self: self._parse_property_assignment(exp.SharingProperty), "SORTKEY": lambda self: self._parse_sortkey(), "SOURCE": lambda self: self._parse_dict_property(this="SOURCE"), "STABLE": lambda self: self.expression( @@ -789,7 +808,7 @@ class Parser(metaclass=_Parser): ), "STORED": lambda self: self._parse_stored(), "SYSTEM_VERSIONING": lambda self: self._parse_system_versioning_property(), - "TBLPROPERTIES": lambda self: self._parse_wrapped_csv(self._parse_property), + "TBLPROPERTIES": lambda self: self._parse_wrapped_properties(), "TEMP": lambda self: self.expression(exp.TemporaryProperty), "TEMPORARY": lambda self: self.expression(exp.TemporaryProperty), "TO": lambda self: self._parse_to_table(), @@ -799,6 +818,7 @@ class Parser(metaclass=_Parser): ), "TTL": lambda self: self._parse_ttl(), "USING": lambda self: self._parse_property_assignment(exp.FileFormatProperty), + "UNLOGGED": lambda self: self.expression(exp.UnloggedProperty), "VOLATILE": lambda self: self._parse_volatile_property(), "WITH": lambda self: self._parse_with_property(), } @@ -832,6 +852,9 @@ class Parser(metaclass=_Parser): exp.DefaultColumnConstraint, this=self._parse_bitwise() ), "ENCODE": lambda self: self.expression(exp.EncodeColumnConstraint, this=self._parse_var()), + "EXCLUDE": lambda self: self.expression( + exp.ExcludeColumnConstraint, this=self._parse_index_params() + ), "FOREIGN KEY": 
lambda self: self._parse_foreign_key(), "FORMAT": lambda self: self.expression( exp.DateFormatColumnConstraint, this=self._parse_var_or_string() @@ -858,7 +881,7 @@ class Parser(metaclass=_Parser): "UNIQUE": lambda self: self._parse_unique(), "UPPERCASE": lambda self: self.expression(exp.UppercaseColumnConstraint), "WITH": lambda self: self.expression( - exp.Properties, expressions=self._parse_wrapped_csv(self._parse_property) + exp.Properties, expressions=self._parse_wrapped_properties() ), } @@ -871,7 +894,15 @@ class Parser(metaclass=_Parser): "RENAME": lambda self: self._parse_alter_table_rename(), } - SCHEMA_UNNAMED_CONSTRAINTS = {"CHECK", "FOREIGN KEY", "LIKE", "PRIMARY KEY", "UNIQUE", "PERIOD"} + SCHEMA_UNNAMED_CONSTRAINTS = { + "CHECK", + "EXCLUDE", + "FOREIGN KEY", + "LIKE", + "PERIOD", + "PRIMARY KEY", + "UNIQUE", + } NO_PAREN_FUNCTION_PARSERS = { "ANY": lambda self: self.expression(exp.Any, this=self._parse_bitwise()), @@ -966,18 +997,54 @@ class Parser(metaclass=_Parser): "READ": ("WRITE", "ONLY"), } + CONFLICT_ACTIONS: OPTIONS_TYPE = dict.fromkeys( + ("ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK", "UPDATE"), tuple() + ) + CONFLICT_ACTIONS["DO"] = ("NOTHING", "UPDATE") + + CREATE_SEQUENCE: OPTIONS_TYPE = { + "SCALE": ("EXTEND", "NOEXTEND"), + "SHARD": ("EXTEND", "NOEXTEND"), + "NO": ("CYCLE", "CACHE", "MAXVALUE", "MINVALUE"), + **dict.fromkeys( + ( + "SESSION", + "GLOBAL", + "KEEP", + "NOKEEP", + "ORDER", + "NOORDER", + "NOCACHE", + "CYCLE", + "NOCYCLE", + "NOMINVALUE", + "NOMAXVALUE", + "NOSCALE", + "NOSHARD", + ), + tuple(), + ), + } + + ISOLATED_LOADING_OPTIONS: OPTIONS_TYPE = {"FOR": ("ALL", "INSERT", "NONE")} + USABLES: OPTIONS_TYPE = dict.fromkeys(("ROLE", "WAREHOUSE", "DATABASE", "SCHEMA"), tuple()) + CAST_ACTIONS: OPTIONS_TYPE = dict.fromkeys(("RENAME", "ADD"), ("FIELDS",)) + INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"} CLONE_KEYWORDS = {"CLONE", "COPY"} HISTORICAL_DATA_KIND = {"TIMESTAMP", "OFFSET", "STATEMENT", "STREAM"} - OPCLASS_FOLLOW_KEYWORDS = {"ASC", "DESC", "NULLS"} + OPCLASS_FOLLOW_KEYWORDS = {"ASC", "DESC", "NULLS", "WITH"} + OPTYPE_FOLLOW_TOKENS = {TokenType.COMMA, TokenType.R_PAREN} TABLE_INDEX_HINT_TOKENS = {TokenType.FORCE, TokenType.IGNORE, TokenType.USE} + VIEW_ATTRIBUTES = {"ENCRYPTION", "SCHEMABINDING", "VIEW_METADATA"} + WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS} WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER} WINDOW_SIDES = {"FOLLOWING", "PRECEDING"} @@ -994,6 +1061,8 @@ class Parser(metaclass=_Parser): UNNEST_OFFSET_ALIAS_TOKENS = ID_VAR_TOKENS - SET_OPERATIONS + SELECT_START_TOKENS = {TokenType.L_PAREN, TokenType.WITH, TokenType.SELECT} + STRICT_CAST = True PREFIXED_PIVOT_COLUMNS = False @@ -1033,6 +1102,9 @@ class Parser(metaclass=_Parser): # Whether implicit unnesting is supported, e.g. SELECT 1 FROM y.z AS z, z.a (Redshift) SUPPORTS_IMPLICIT_UNNEST = False + # Whether or not interval spans are supported, INTERVAL 1 YEAR TO MONTHS + INTERVAL_SPANS = True + __slots__ = ( "error_level", "error_message_context", @@ -1285,6 +1357,27 @@ class Parser(metaclass=_Parser): exp.Command, this=self._prev.text.upper(), expression=self._parse_string() ) + def _try_parse(self, parse_method: t.Callable[[], T], retreat: bool = False) -> t.Optional[T]: + """ + Attempts to backtrack if a parse function that contains a try/catch internally raises an error.
This behavior can + be different depending on the user-set ErrorLevel, so _try_parse aims to solve this by setting & resetting + the parser state accordingly + """ + index = self._index + error_level = self.error_level + + self.error_level = ErrorLevel.IMMEDIATE + try: + this = parse_method() + except ParseError: + this = None + finally: + if not this or retreat: + self._retreat(index) + self.error_level = error_level + + return this + def _parse_comment(self, allow_exists: bool = True) -> exp.Expression: start = self._prev exists = self._parse_exists() if allow_exists else None @@ -1377,13 +1470,22 @@ class Parser(metaclass=_Parser): if not kind: return self._parse_as_command(start) + if_exists = exists or self._parse_exists() + table = self._parse_table_parts( + schema=True, is_db_reference=self._prev.token_type == TokenType.SCHEMA + ) + + if self._match(TokenType.L_PAREN, advance=False): + expressions = self._parse_wrapped_csv(self._parse_types) + else: + expressions = None + return self.expression( exp.Drop, comments=start.comments, - exists=exists or self._parse_exists(), - this=self._parse_table( - schema=True, is_db_reference=self._prev.token_type == TokenType.SCHEMA - ), + exists=if_exists, + this=table, + expressions=expressions, kind=kind, temporary=temporary, materialized=materialized, @@ -1409,6 +1511,7 @@ class Parser(metaclass=_Parser): or self._match_pair(TokenType.OR, TokenType.REPLACE) or self._match_pair(TokenType.OR, TokenType.ALTER) ) + unique = self._match(TokenType.UNIQUE) if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False): @@ -1489,7 +1592,11 @@ class Parser(metaclass=_Parser): # exp.Properties.Location.POST_ALIAS extend_props(self._parse_properties()) - expression = self._parse_ddl_select() + if create_token.token_type == TokenType.SEQUENCE: + expression = self._parse_types() + extend_props(self._parse_properties()) + else: + expression = self._parse_ddl_select() if create_token.token_type == TokenType.TABLE: # exp.Properties.Location.POST_EXPRESSION @@ -1539,6 +1646,40 @@ class Parser(metaclass=_Parser): clone=clone, ) + def _parse_sequence_properties(self) -> t.Optional[exp.SequenceProperties]: + seq = exp.SequenceProperties() + + options = [] + index = self._index + + while self._curr: + if self._match_text_seq("INCREMENT"): + self._match_text_seq("BY") + self._match_text_seq("=") + seq.set("increment", self._parse_term()) + elif self._match_text_seq("MINVALUE"): + seq.set("minvalue", self._parse_term()) + elif self._match_text_seq("MAXVALUE"): + seq.set("maxvalue", self._parse_term()) + elif self._match(TokenType.START_WITH) or self._match_text_seq("START"): + self._match_text_seq("=") + seq.set("start", self._parse_term()) + elif self._match_text_seq("CACHE"): + # T-SQL allows empty CACHE which is initialized dynamically + seq.set("cache", self._parse_number() or True) + elif self._match_text_seq("OWNED", "BY"): + # "OWNED BY NONE" is the default + seq.set("owned", None if self._match_text_seq("NONE") else self._parse_column()) + else: + opt = self._parse_var_from_options(self.CREATE_SEQUENCE, raise_unmatched=False) + if opt: + options.append(opt) + else: + break + + seq.set("options", options if options else None) + return None if self._index == index else seq + def _parse_property_before(self) -> t.Optional[exp.Expression]: # only used for teradata currently self._match(TokenType.COMMA) @@ -1564,6 +1705,9 @@ class Parser(metaclass=_Parser): return None + def _parse_wrapped_properties(self) -> t.List[exp.Expression]: + return
self._parse_wrapped_csv(self._parse_property) + def _parse_property(self) -> t.Optional[exp.Expression]: if self._match_texts(self.PROPERTY_PARSERS): return self.PROPERTY_PARSERS[self._prev.text.upper()](self) @@ -1582,12 +1726,12 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.EQ): self._retreat(index) - return None + return self._parse_sequence_properties() return self.expression( exp.Property, this=key.to_dot() if isinstance(key, exp.Column) else key, - value=self._parse_column() or self._parse_var(any_token=True), + value=self._parse_bitwise() or self._parse_var(any_token=True), ) def _parse_stored(self) -> exp.FileFormatProperty: @@ -1619,7 +1763,6 @@ class Parser(metaclass=_Parser): prop = self._parse_property_before() else: prop = self._parse_property() - if not prop: break for p in ensure_list(prop): @@ -1662,15 +1805,16 @@ class Parser(metaclass=_Parser): return prop - def _parse_with_property( - self, - ) -> t.Optional[exp.Expression] | t.List[exp.Expression]: + def _parse_with_property(self) -> t.Optional[exp.Expression] | t.List[exp.Expression]: if self._match(TokenType.L_PAREN, advance=False): - return self._parse_wrapped_csv(self._parse_property) + return self._parse_wrapped_properties() if self._match_text_seq("JOURNAL"): return self._parse_withjournaltable() + if self._match_texts(self.VIEW_ATTRIBUTES): + return self.expression(exp.ViewAttributeProperty, this=self._prev.text.upper()) + if self._match_text_seq("DATA"): return self._parse_withdata(no=False) elif self._match_text_seq("NO", "DATA"): @@ -1818,20 +1962,18 @@ class Parser(metaclass=_Parser): autotemp=autotemp, ) - def _parse_withisolatedloading(self) -> exp.IsolatedLoadingProperty: + def _parse_withisolatedloading(self) -> t.Optional[exp.IsolatedLoadingProperty]: + index = self._index no = self._match_text_seq("NO") concurrent = self._match_text_seq("CONCURRENT") - self._match_text_seq("ISOLATED", "LOADING") - for_all = self._match_text_seq("FOR", "ALL") - for_insert = self._match_text_seq("FOR", "INSERT") - for_none = self._match_text_seq("FOR", "NONE") + + if not self._match_text_seq("ISOLATED", "LOADING"): + self._retreat(index) + return None + + target = self._parse_var_from_options(self.ISOLATED_LOADING_OPTIONS, raise_unmatched=False) return self.expression( - exp.IsolatedLoadingProperty, - no=no, - concurrent=concurrent, - for_all=for_all, - for_insert=for_insert, - for_none=for_none, + exp.IsolatedLoadingProperty, no=no, concurrent=concurrent, target=target ) def _parse_locking(self) -> exp.LockingProperty: @@ -2046,20 +2188,22 @@ class Parser(metaclass=_Parser): def _parse_describe(self) -> exp.Describe: kind = self._match_set(self.CREATABLES) and self._prev.text - extended = self._match_text_seq("EXTENDED") + style = self._match_texts(("EXTENDED", "FORMATTED")) and self._prev.text.upper() this = self._parse_table(schema=True) properties = self._parse_properties() expressions = properties.expressions if properties else None return self.expression( - exp.Describe, this=this, extended=extended, kind=kind, expressions=expressions + exp.Describe, this=this, style=style, kind=kind, expressions=expressions ) def _parse_insert(self) -> exp.Insert: comments = ensure_list(self._prev_comments) + hint = self._parse_hint() overwrite = self._match(TokenType.OVERWRITE) ignore = self._match(TokenType.IGNORE) local = self._match_text_seq("LOCAL") alternative = None + is_function = None if self._match_text_seq("DIRECTORY"): this: t.Optional[exp.Expression] = self.expression( @@ -2075,13 +2219,17 @@ class 
Parser(metaclass=_Parser): self._match(TokenType.INTO) comments += ensure_list(self._prev_comments) self._match(TokenType.TABLE) - this = self._parse_table(schema=True) + is_function = self._match(TokenType.FUNCTION) + + this = self._parse_table(schema=True) if not is_function else self._parse_function() returning = self._parse_returning() return self.expression( exp.Insert, comments=comments, + hint=hint, + is_function=is_function, this=this, by_name=self._match_text_seq("BY", "NAME"), exists=self._parse_exists(), @@ -2112,31 +2260,29 @@ class Parser(metaclass=_Parser): if not conflict and not duplicate: return None - nothing = None - expressions = None - key = None + conflict_keys = None constraint = None if conflict: if self._match_text_seq("ON", "CONSTRAINT"): constraint = self._parse_id_var() - else: - key = self._parse_csv(self._parse_value) + elif self._match(TokenType.L_PAREN): + conflict_keys = self._parse_csv(self._parse_id_var) + self._match_r_paren() - self._match_text_seq("DO") - if self._match_text_seq("NOTHING"): - nothing = True - else: - self._match(TokenType.UPDATE) + action = self._parse_var_from_options(self.CONFLICT_ACTIONS) + if self._prev.token_type == TokenType.UPDATE: self._match(TokenType.SET) expressions = self._parse_csv(self._parse_equality) + else: + expressions = None return self.expression( exp.OnConflict, duplicate=duplicate, expressions=expressions, - nothing=nothing, - key=key, + action=action, + conflict_keys=conflict_keys, constraint=constraint, ) @@ -2166,7 +2312,7 @@ class Parser(metaclass=_Parser): serde_properties = None if self._match(TokenType.SERDE_PROPERTIES): serde_properties = self.expression( - exp.SerdeProperties, expressions=self._parse_wrapped_csv(self._parse_property) + exp.SerdeProperties, expressions=self._parse_wrapped_properties() ) return self.expression( @@ -2433,8 +2579,19 @@ class Parser(metaclass=_Parser): self.raise_error("Expected CTE to have alias") self._match(TokenType.ALIAS) + + if self._match_text_seq("NOT", "MATERIALIZED"): + materialized = False + elif self._match_text_seq("MATERIALIZED"): + materialized = True + else: + materialized = None + return self.expression( - exp.CTE, this=self._parse_wrapped(self._parse_statement), alias=alias + exp.CTE, + this=self._parse_wrapped(self._parse_statement), + alias=alias, + materialized=materialized, ) def _parse_table_alias( @@ -2472,7 +2629,9 @@ class Parser(metaclass=_Parser): ) def _implicit_unnests_to_explicit(self, this: E) -> E: - from sqlglot.optimizer.normalize_identifiers import normalize_identifiers as _norm + from sqlglot.optimizer.normalize_identifiers import ( + normalize_identifiers as _norm, + ) refs = {_norm(this.args["from"].this.copy(), dialect=self.dialect).alias_or_name} for i, join in enumerate(this.args.get("joins") or []): @@ -2502,7 +2661,7 @@ class Parser(metaclass=_Parser): self, this: t.Optional[exp.Expression] ) -> t.Optional[exp.Expression]: if isinstance(this, (exp.Query, exp.Table)): - for join in iter(self._parse_join, None): + for join in self._parse_joins(): this.append("joins", join) for lateral in iter(self._parse_lateral, None): this.append("laterals", lateral) @@ -2535,7 +2694,12 @@ class Parser(metaclass=_Parser): def _parse_hint(self) -> t.Optional[exp.Hint]: if self._match(TokenType.HINT): hints = [] - for hint in iter(lambda: self._parse_csv(self._parse_function), []): + for hint in iter( + lambda: self._parse_csv( + lambda: self._parse_function() or self._parse_var(upper=True) + ), + [], + ): hints.extend(hint) if not 
self._match_pair(TokenType.STAR, TokenType.SLASH): @@ -2743,29 +2907,35 @@ class Parser(metaclass=_Parser): if hint: kwargs["hint"] = hint + if self._match(TokenType.MATCH_CONDITION): + kwargs["match_condition"] = self._parse_wrapped(self._parse_comparison) + if self._match(TokenType.ON): kwargs["on"] = self._parse_conjunction() elif self._match(TokenType.USING): kwargs["using"] = self._parse_wrapped_id_vars() - elif not (kind and kind.token_type == TokenType.CROSS): + elif not isinstance(kwargs["this"], exp.Unnest) and not ( + kind and kind.token_type == TokenType.CROSS + ): index = self._index - join = self._parse_join() + joins: t.Optional[list] = list(self._parse_joins()) - if join and self._match(TokenType.ON): + if joins and self._match(TokenType.ON): kwargs["on"] = self._parse_conjunction() - elif join and self._match(TokenType.USING): + elif joins and self._match(TokenType.USING): kwargs["using"] = self._parse_wrapped_id_vars() else: - join = None + joins = None self._retreat(index) - kwargs["this"].set("joins", [join] if join else None) + kwargs["this"].set("joins", joins if joins else None) comments = [c for token in (method, side, kind) if token for c in token.comments] return self.expression(exp.Join, comments=comments, **kwargs) def _parse_opclass(self) -> t.Optional[exp.Expression]: this = self._parse_conjunction() + if self._match_texts(self.OPCLASS_FOLLOW_KEYWORDS, advance=False): return this @@ -2774,6 +2944,35 @@ class Parser(metaclass=_Parser): return this + def _parse_index_params(self) -> exp.IndexParameters: + using = self._parse_var(any_token=True) if self._match(TokenType.USING) else None + + if self._match(TokenType.L_PAREN, advance=False): + columns = self._parse_wrapped_csv(self._parse_with_operator) + else: + columns = None + + include = self._parse_wrapped_id_vars() if self._match_text_seq("INCLUDE") else None + partition_by = self._parse_partition_by() + with_storage = self._match(TokenType.WITH) and self._parse_wrapped_properties() + tablespace = ( + self._parse_var(any_token=True) + if self._match_text_seq("USING", "INDEX", "TABLESPACE") + else None + ) + where = self._parse_where() + + return self.expression( + exp.IndexParameters, + using=using, + columns=columns, + include=include, + partition_by=partition_by, + where=where, + with_storage=with_storage, + tablespace=tablespace, + ) + def _parse_index( self, index: t.Optional[exp.Expression] = None, @@ -2797,27 +2996,16 @@ class Parser(metaclass=_Parser): index = self._parse_id_var() table = None - using = self._parse_var(any_token=True) if self._match(TokenType.USING) else None - - if self._match(TokenType.L_PAREN, advance=False): - columns = self._parse_wrapped_csv(lambda: self._parse_ordered(self._parse_opclass)) - else: - columns = None - - include = self._parse_wrapped_id_vars() if self._match_text_seq("INCLUDE") else None + params = self._parse_index_params() return self.expression( exp.Index, this=index, table=table, - using=using, - columns=columns, unique=unique, primary=primary, amp=amp, - include=include, - partition_by=self._parse_partition_by(), - where=self._parse_where(), + params=params, ) def _parse_table_hints(self) -> t.Optional[t.List[exp.Expression]]: @@ -2977,7 +3165,7 @@ class Parser(metaclass=_Parser): this = table_sample if joins: - for join in iter(self._parse_join, None): + for join in self._parse_joins(): this.append("joins", join) if self._match_pair(TokenType.WITH, TokenType.ORDINALITY): @@ -3126,8 +3314,8 @@ class Parser(metaclass=_Parser): def _parse_pivots(self) -> 
t.Optional[t.List[exp.Pivot]]: return list(iter(self._parse_pivot, None)) or None - def _parse_joins(self) -> t.Optional[t.List[exp.Join]]: - return list(iter(self._parse_join, None)) or None + def _parse_joins(self) -> t.Iterator[exp.Join]: + return iter(self._parse_join, None) # https://duckdb.org/docs/sql/statements/pivot def _parse_simplified_pivot(self) -> exp.Pivot: @@ -3328,6 +3516,7 @@ class Parser(metaclass=_Parser): return None self._match(TokenType.CONNECT_BY) + nocycle = self._match_text_seq("NOCYCLE") self.NO_PAREN_FUNCTION_PARSERS["PRIOR"] = lambda self: self.expression( exp.Prior, this=self._parse_bitwise() ) @@ -3337,7 +3526,7 @@ class Parser(metaclass=_Parser): if not start and self._match(TokenType.START_WITH): start = self._parse_conjunction() - return self.expression(exp.Connect, start=start, connect=connect) + return self.expression(exp.Connect, start=start, connect=connect, nocycle=nocycle) def _parse_name_as_expression(self) -> exp.Alias: return self.expression( @@ -3417,9 +3606,12 @@ class Parser(metaclass=_Parser): ) def _parse_limit( - self, this: t.Optional[exp.Expression] = None, top: bool = False + self, + this: t.Optional[exp.Expression] = None, + top: bool = False, + skip_limit_token: bool = False, ) -> t.Optional[exp.Expression]: - if self._match(TokenType.TOP if top else TokenType.LIMIT): + if skip_limit_token or self._match(TokenType.TOP if top else TokenType.LIMIT): comments = self._prev_comments if top: limit_paren = self._match(TokenType.L_PAREN) @@ -3681,6 +3873,11 @@ class Parser(metaclass=_Parser): this = exp.Literal.string(parts[0]) unit = self.expression(exp.Var, this=parts[1].upper()) + if self.INTERVAL_SPANS and self._match_text_seq("TO"): + unit = self.expression( + exp.IntervalSpan, this=unit, expression=self._parse_var(any_token=True, upper=True) + ) + return self.expression(exp.Interval, this=this, unit=unit) def _parse_bitwise(self) -> t.Optional[exp.Expression]: @@ -3783,6 +3980,9 @@ class Parser(metaclass=_Parser): if not this: return None + if isinstance(this, exp.Column) and not this.table: + this = exp.var(this.name.upper()) + return self.expression( exp.DataTypeParam, this=this, expression=self._parse_var(any_token=True) ) @@ -3900,19 +4100,14 @@ class Parser(metaclass=_Parser): elif self._match_text_seq("WITHOUT", "TIME", "ZONE"): maybe_func = False elif type_token == TokenType.INTERVAL: - unit = self._parse_var() - - if self._match_text_seq("TO"): - span = [exp.IntervalSpan(this=unit, expression=self._parse_var())] - else: - span = None + unit = self._parse_var(upper=True) + if unit: + if self._match_text_seq("TO"): + unit = exp.IntervalSpan(this=unit, expression=self._parse_var(upper=True)) - if span or not unit: - this = self.expression( - exp.DataType, this=exp.DataType.Type.INTERVAL, expressions=span - ) - else: this = self.expression(exp.DataType, this=self.expression(exp.Interval, unit=unit)) + else: + this = self.expression(exp.DataType, this=exp.DataType.Type.INTERVAL) if maybe_func and check_func: index2 = self._index @@ -3996,11 +4191,20 @@ class Parser(metaclass=_Parser): else: field = self._parse_field(anonymous_func=True, any_token=True) - if isinstance(field, exp.Func): + if isinstance(field, exp.Func) and this: # bigquery allows function calls like x.y.count(...) # SAFE.SUBSTR(...) 
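One of the parser additions above in action: with `INTERVAL_SPANS` enabled, units like `YEAR TO MONTH` now parse into an `exp.IntervalSpan` node. A small sketch (illustrative; the expected output is shown in the comment):

import sqlglot
from sqlglot import exp

ast = sqlglot.parse_one("SELECT INTERVAL '1-2' YEAR TO MONTH")
print(ast.find(exp.IntervalSpan).sql())  # YEAR TO MONTH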
# https://cloud.google.com/bigquery/docs/reference/standard-sql/functions-reference#function_call_rules - this = self._replace_columns_with_dots(this) + this = exp.replace_tree( + this, + lambda n: ( + self.expression(exp.Dot, this=n.args.get("table"), expression=n.this) + if n.table + else n.this + ) + if isinstance(n, exp.Column) + else n, + ) if op: this = op(self, this, field) @@ -4050,10 +4254,14 @@ class Parser(metaclass=_Parser): this = self._parse_set_operations( self._parse_subquery(this=this, parse_alias=False) ) + elif isinstance(this, exp.Subquery): + this = self._parse_subquery( + this=self._parse_set_operations(this), parse_alias=False + ) elif len(expressions) > 1: this = self.expression(exp.Tuple, expressions=expressions) else: - this = self.expression(exp.Paren, this=self._parse_set_operations(this)) + this = self.expression(exp.Paren, this=this) if this: this.add_comments(comments) @@ -4118,7 +4326,7 @@ class Parser(metaclass=_Parser): parser = self.NO_PAREN_FUNCTION_PARSERS.get(upper) if optional_parens and parser and token_type not in self.INVALID_FUNC_NAME_TOKENS: self._advance() - return parser(self) + return self._parse_window(parser(self)) if not self._next or self._next.token_type != TokenType.L_PAREN: if optional_parens and token_type in self.NO_PAREN_FUNCTIONS: @@ -4186,7 +4394,7 @@ class Parser(metaclass=_Parser): if not isinstance(e, exp.PropertyEQ): e = self.expression( - exp.PropertyEQ, this=exp.to_identifier(e.name), expression=e.expression + exp.PropertyEQ, this=exp.to_identifier(e.this.name), expression=e.expression ) if isinstance(e.this, exp.Column): @@ -4267,19 +4475,15 @@ class Parser(metaclass=_Parser): def _parse_schema(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]: index = self._index - if not self.errors: - try: - if self._parse_select(nested=True): - return this - except ParseError: - pass - finally: - self.errors.clear() - self._retreat(index) - if not self._match(TokenType.L_PAREN): return this + # Disambiguate between schema and subquery/CTE, e.g. 
in INSERT INTO table (<expr>), + # expr can be of both types + if self._match_set(self.SELECT_START_TOKENS): + self._retreat(index) + return this + args = self._parse_csv(lambda: self._parse_constraint() or self._parse_field_def()) self._match_r_paren() @@ -4300,7 +4504,7 @@ class Parser(metaclass=_Parser): constraints: t.List[exp.Expression] = [] - if not kind and self._match(TokenType.ALIAS): + if (not kind and self._match(TokenType.ALIAS)) or self._match_text_seq("ALIAS"): constraints.append( self.expression( exp.ComputedColumnConstraint, @@ -4417,9 +4621,7 @@ class Parser(metaclass=_Parser): self._match_text_seq("LENGTH") return self.expression(exp.InlineLengthColumnConstraint, this=self._parse_bitwise()) - def _parse_not_constraint( - self, - ) -> t.Optional[exp.Expression]: + def _parse_not_constraint(self) -> t.Optional[exp.Expression]: if self._match_text_seq("NULL"): return self.expression(exp.NotNullColumnConstraint) if self._match_text_seq("CASESPECIFIC"): @@ -4447,16 +4649,21 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.CONSTRAINT): return self._parse_unnamed_constraint(constraints=self.SCHEMA_UNNAMED_CONSTRAINTS) - this = self._parse_id_var() - expressions = [] + return self.expression( + exp.Constraint, + this=self._parse_id_var(), + expressions=self._parse_unnamed_constraints(), + ) + def _parse_unnamed_constraints(self) -> t.List[exp.Expression]: + constraints = [] while True: constraint = self._parse_unnamed_constraint() or self._parse_function() if not constraint: break - expressions.append(constraint) + constraints.append(constraint) - return self.expression(exp.Constraint, this=this, expressions=expressions) + return constraints def _parse_unnamed_constraint( self, constraints: t.Optional[t.Collection[str]] = None @@ -4478,6 +4685,7 @@ class Parser(metaclass=_Parser): exp.UniqueColumnConstraint, this=self._parse_schema(self._parse_id_var(any_token=False)), index_type=self._match(TokenType.USING) and self._advance_any() and self._prev.text, + on_conflict=self._parse_on_conflict(), ) def _parse_key_constraint_options(self) -> t.List[str]: @@ -4592,7 +4800,7 @@ class Parser(metaclass=_Parser): def _parse_bracket_key_value(self, is_map: bool = False) -> t.Optional[exp.Expression]: return self._parse_slice(self._parse_alias(self._parse_conjunction(), explicit=True)) - def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: + def _parse_bracket(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]: if not self._match_set((TokenType.L_BRACKET, TokenType.L_BRACE)): return this @@ -4601,9 +4809,9 @@ class Parser(metaclass=_Parser): lambda: self._parse_bracket_key_value(is_map=bracket_kind == TokenType.L_BRACE) ) - if not self._match(TokenType.R_BRACKET) and bracket_kind == TokenType.L_BRACKET: + if bracket_kind == TokenType.L_BRACKET and not self._match(TokenType.R_BRACKET): self.raise_error("Expected ]") - elif not self._match(TokenType.R_BRACE) and bracket_kind == TokenType.L_BRACE: + elif bracket_kind == TokenType.L_BRACE and not self._match(TokenType.R_BRACE): self.raise_error("Expected }") # https://duckdb.org/docs/sql/data_types/struct.html#creating-structs @@ -4645,8 +4853,8 @@ class Parser(metaclass=_Parser): else: self.raise_error("Expected END after CASE", self._prev) - return self._parse_window( - self.expression(exp.Case, comments=comments, this=expression, ifs=ifs, default=default) + return self.expression( + exp.Case, comments=comments, this=expression, ifs=ifs, default=default ) def 
_parse_if(self) -> t.Optional[exp.Expression]: @@ -4672,7 +4880,7 @@ class Parser(metaclass=_Parser): self._match(TokenType.END) this = self.expression(exp.If, this=condition, true=true, false=false) - return self._parse_window(this) + return this def _parse_next_value_for(self) -> t.Optional[exp.Expression]: if not self._match_text_seq("VALUE", "FOR"): @@ -4739,7 +4947,12 @@ class Parser(metaclass=_Parser): to = self.expression(exp.CharacterSet, this=self._parse_var_or_string()) return self.expression( - exp.Cast if strict else exp.TryCast, this=this, to=to, format=fmt, safe=safe + exp.Cast if strict else exp.TryCast, + this=this, + to=to, + format=fmt, + safe=safe, + action=self._parse_var_from_options(self.CAST_ACTIONS, raise_unmatched=False), ) def _parse_string_agg(self) -> exp.Expression: @@ -5087,6 +5300,9 @@ class Parser(metaclass=_Parser): def _parse_window( self, this: t.Optional[exp.Expression], alias: bool = False ) -> t.Optional[exp.Expression]: + func = this + comments = func.comments if isinstance(func, exp.Expression) else None + if self._match_pair(TokenType.FILTER, TokenType.L_PAREN): self._match(TokenType.WHERE) this = self.expression( @@ -5132,9 +5348,16 @@ class Parser(metaclass=_Parser): else: over = self._prev.text.upper() + if comments: + func.comments = None # type: ignore + if not self._match(TokenType.L_PAREN): return self.expression( - exp.Window, this=this, alias=self._parse_id_var(False), over=over + exp.Window, + comments=comments, + this=this, + alias=self._parse_id_var(False), + over=over, ) window_alias = self._parse_id_var(any_token=False, tokens=self.WINDOW_ALIAS_TOKENS) @@ -5167,6 +5390,7 @@ class Parser(metaclass=_Parser): window = self.expression( exp.Window, + comments=comments, this=this, partition_by=partition, order=order, @@ -5218,7 +5442,7 @@ class Parser(metaclass=_Parser): self._match_r_paren(aliases) return aliases - alias = self._parse_id_var(any_token) or ( + alias = self._parse_id_var(any_token, tokens=self.ALIAS_TOKENS) or ( self.STRING_ALIASES and self._parse_string_as_identifier() ) @@ -5512,10 +5736,11 @@ class Parser(metaclass=_Parser): return self.expression(exp.AlterColumn, this=column, comment=self._parse_string()) self._match_text_seq("SET", "DATA") + self._match_text_seq("TYPE") return self.expression( exp.AlterColumn, this=column, - dtype=self._match_text_seq("TYPE") and self._parse_types(), + dtype=self._parse_types(), collate=self._match(TokenType.COLLATE) and self._parse_term(), using=self._match(TokenType.USING) and self._parse_conjunction(), ) @@ -5919,26 +6144,6 @@ class Parser(metaclass=_Parser): return True - @t.overload - def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression: ... - - @t.overload - def _replace_columns_with_dots( - self, this: t.Optional[exp.Expression] - ) -> t.Optional[exp.Expression]: ... 
- - def _replace_columns_with_dots(self, this): - if isinstance(this, exp.Dot): - exp.replace_children(this, self._replace_columns_with_dots) - elif isinstance(this, exp.Column): - exp.replace_children(this, self._replace_columns_with_dots) - table = this.args.get("table") - this = ( - self.expression(exp.Dot, this=table, expression=this.this) if table else this.this - ) - - return this - def _replace_lambda( self, node: t.Optional[exp.Expression], lambda_variables: t.Set[str] ) -> t.Optional[exp.Expression]: @@ -6011,3 +6216,13 @@ class Parser(metaclass=_Parser): option=option, partition=partition, ) + + def _parse_with_operator(self) -> t.Optional[exp.Expression]: + this = self._parse_ordered(self._parse_opclass) + + if not self._match(TokenType.WITH): + return this + + op = self._parse_var(any_token=True) + + return self.expression(exp.WithOperator, this=this, op=op) diff --git a/sqlglot/planner.py b/sqlglot/planner.py index bbc52ab..5e4e23a 100644 --- a/sqlglot/planner.py +++ b/sqlglot/planner.py @@ -118,6 +118,7 @@ class Step: if joins: join = Join.from_joins(joins, ctes) join.name = step.name + join.source_name = step.name join.add_dependency(step) step = join @@ -187,13 +188,13 @@ class Step: intermediate[v.name] = k for projection in projections: - for node, *_ in projection.walk(): + for node in projection.walk(): name = intermediate.get(node) if name: node.replace(exp.column(name, step.name)) if aggregate.condition: - for node, *_ in aggregate.condition.walk(): + for node in aggregate.condition.walk(): name = intermediate.get(node) or intermediate.get(node.name) if name: node.replace(exp.column(name, step.name)) @@ -331,7 +332,7 @@ class Join(Step): @classmethod def from_joins( cls, joins: t.Iterable[exp.Join], ctes: t.Optional[t.Dict[str, Step]] = None - ) -> Step: + ) -> Join: step = Join() for join in joins: @@ -349,10 +350,11 @@ class Join(Step): def __init__(self) -> None: super().__init__() + self.source_name: t.Optional[str] = None self.joins: t.Dict[str, t.Dict[str, t.List[str] | exp.Expression]] = {} def _to_s(self, indent: str) -> t.List[str]: - lines = [] + lines = [f"{indent}Source: {self.source_name or self.name}"] for name, join in self.joins.items(): lines.append(f"{indent}{name}: {join['side'] or 'INNER'}") join_key = ", ".join(str(key) for key in t.cast(list, join.get("join_key") or [])) @@ -423,7 +425,7 @@ class SetOperation(Step): @classmethod def from_expression( cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None - ) -> Step: + ) -> SetOperation: assert isinstance(expression, exp.Union) left = Step.from_expression(expression.left, ctes) diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index da9df7d..7f0cb5d 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -135,6 +135,7 @@ class TokenType(AutoName): LONGBLOB = auto() TINYBLOB = auto() TINYTEXT = auto() + NAME = auto() BINARY = auto() VARBINARY = auto() JSON = auto() @@ -290,6 +291,7 @@ class TokenType(AutoName): LOAD = auto() LOCK = auto() MAP = auto() + MATCH_CONDITION = auto() MATCH_RECOGNIZE = auto() MEMBER_OF = auto() MERGE = auto() @@ -317,6 +319,7 @@ class TokenType(AutoName): PERCENT = auto() PIVOT = auto() PLACEHOLDER = auto() + POSITIONAL = auto() PRAGMA = auto() PREWHERE = auto() PRIMARY_KEY = auto() @@ -340,6 +343,7 @@ class TokenType(AutoName): SELECT = auto() SEMI = auto() SEPARATOR = auto() + SEQUENCE = auto() SERDE_PROPERTIES = auto() SET = auto() SETTINGS = auto() @@ -518,6 +522,7 @@ class _Tokenizer(type): break_=_TOKEN_TYPE_TO_INDEX[TokenType.BREAK], 
dcolon=_TOKEN_TYPE_TO_INDEX[TokenType.DCOLON], heredoc_string=_TOKEN_TYPE_TO_INDEX[TokenType.HEREDOC_STRING], + raw_string=_TOKEN_TYPE_TO_INDEX[TokenType.RAW_STRING], hex_string=_TOKEN_TYPE_TO_INDEX[TokenType.HEX_STRING], identifier=_TOKEN_TYPE_TO_INDEX[TokenType.IDENTIFIER], number=_TOKEN_TYPE_TO_INDEX[TokenType.NUMBER], @@ -562,8 +567,7 @@ class Tokenizer(metaclass=_Tokenizer): "~": TokenType.TILDA, "?": TokenType.PLACEHOLDER, "@": TokenType.PARAMETER, - # used for breaking a var like x'y' but nothing else - # the token type doesn't matter + # Used for breaking a var like x'y' but nothing else the token type doesn't matter "'": TokenType.QUOTE, "`": TokenType.IDENTIFIER, '"': TokenType.IDENTIFIER, @@ -796,6 +800,7 @@ class Tokenizer(metaclass=_Tokenizer): "LONG": TokenType.BIGINT, "BIGINT": TokenType.BIGINT, "INT8": TokenType.TINYINT, + "UINT": TokenType.UINT, "DEC": TokenType.DECIMAL, "DECIMAL": TokenType.DECIMAL, "BIGDECIMAL": TokenType.BIGDECIMAL, @@ -856,6 +861,7 @@ class Tokenizer(metaclass=_Tokenizer): "DATEMULTIRANGE": TokenType.DATEMULTIRANGE, "UNIQUE": TokenType.UNIQUE, "STRUCT": TokenType.STRUCT, + "SEQUENCE": TokenType.SEQUENCE, "VARIANT": TokenType.VARIANT, "ALTER": TokenType.ALTER, "ANALYZE": TokenType.COMMAND, @@ -888,7 +894,7 @@ class Tokenizer(metaclass=_Tokenizer): COMMAND_PREFIX_TOKENS = {TokenType.SEMICOLON, TokenType.BEGIN} - # handle numeric literals like in hive (3L = BIGINT) + # Handle numeric literals like in hive (3L = BIGINT) NUMERIC_LITERALS: t.Dict[str, str] = {} COMMENTS = ["--", ("/*", "*/")] @@ -917,7 +923,7 @@ class Tokenizer(metaclass=_Tokenizer): if USE_RS_TOKENIZER: self._rs_dialect_settings = RsTokenizerDialectSettings( - escape_sequences=self.dialect.ESCAPE_SEQUENCES, + unescaped_sequences=self.dialect.UNESCAPED_SEQUENCES, identifiers_can_start_with_digit=self.dialect.IDENTIFIERS_CAN_START_WITH_DIGIT, ) @@ -961,8 +967,7 @@ class Tokenizer(metaclass=_Tokenizer): while self.size and not self._end: current = self._current - # skip spaces inline rather than iteratively call advance() - # for performance reasons + # Skip spaces here rather than iteratively calling advance() for performance reasons while current < self.size: char = self.sql[current] @@ -971,12 +976,10 @@ class Tokenizer(metaclass=_Tokenizer): else: break - n = current - self._current - self._start = current - self._advance(n if n > 1 else 1) + offset = current - self._current if current > self._current else 1 - if self._char is None: - break + self._start = current + self._advance(offset) if not self._char.isspace(): if self._char.isdigit(): @@ -1004,12 +1007,9 @@ class Tokenizer(metaclass=_Tokenizer): def _advance(self, i: int = 1, alnum: bool = False) -> None: if self.WHITE_SPACE.get(self._char) is TokenType.BREAK: # Ensures we don't count an extra line if we get a \r\n line break sequence - if self._char == "\r" and self._peek == "\n": - i = 2 - self._start += 1 - - self._col = 1 - self._line += 1 + if not (self._char == "\r" and self._peek == "\n"): + self._col = 1 + self._line += 1 else: self._col += i @@ -1268,13 +1268,27 @@ class Tokenizer(metaclass=_Tokenizer): return True self._advance() - tag = "" if self._char == end else self._extract_string(end) + + if self._char == end: + tag = "" + else: + tag = self._extract_string( + end, + unescape_sequences=False, + raise_unmatched=not self.HEREDOC_TAG_IS_IDENTIFIER, + ) + + if self._end and tag and self.HEREDOC_TAG_IS_IDENTIFIER: + self._advance(-len(tag)) + self._add(self.HEREDOC_STRING_ALTERNATIVE) + return True + end = 
f"{start}{tag}{end}" else: return False self._advance(len(start)) - text = self._extract_string(end) + text = self._extract_string(end, unescape_sequences=token_type != TokenType.RAW_STRING) if base: try: @@ -1289,7 +1303,7 @@ class Tokenizer(metaclass=_Tokenizer): def _scan_identifier(self, identifier_end: str) -> None: self._advance() - text = self._extract_string(identifier_end, self._IDENTIFIER_ESCAPES) + text = self._extract_string(identifier_end, escapes=self._IDENTIFIER_ESCAPES) self._add(TokenType.IDENTIFIER, text) def _scan_var(self) -> None: @@ -1306,13 +1320,30 @@ class Tokenizer(metaclass=_Tokenizer): else self.KEYWORDS.get(self._text.upper(), TokenType.VAR) ) - def _extract_string(self, delimiter: str, escapes=None) -> str: + def _extract_string( + self, + delimiter: str, + escapes: t.Optional[t.Set[str]] = None, + unescape_sequences: bool = True, + raise_unmatched: bool = True, + ) -> str: text = "" delim_size = len(delimiter) escapes = self._STRING_ESCAPES if escapes is None else escapes while True: if ( + unescape_sequences + and self.dialect.UNESCAPED_SEQUENCES + and self._peek + and self._char in self.STRING_ESCAPES + ): + unescaped_sequence = self.dialect.UNESCAPED_SEQUENCES.get(self._char + self._peek) + if unescaped_sequence: + self._advance(2) + text += unescaped_sequence + continue + if ( self._char in escapes and (self._peek == delimiter or self._peek in escapes) and (self._char not in self._QUOTES or self._char == self._peek) @@ -1333,18 +1364,10 @@ class Tokenizer(metaclass=_Tokenizer): break if self._end: - raise TokenError(f"Missing {delimiter} from {self._line}:{self._start}") + if not raise_unmatched: + return text + self._char - if ( - self.dialect.ESCAPE_SEQUENCES - and self._peek - and self._char in self.STRING_ESCAPES - ): - escaped_sequence = self.dialect.ESCAPE_SEQUENCES.get(self._char + self._peek) - if escaped_sequence: - self._advance(2) - text += escaped_sequence - continue + raise TokenError(f"Missing {delimiter} from {self._line}:{self._start}") current = self._current - 1 self._advance(alnum=True) diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 04c1f7b..f44c18c 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -447,7 +447,7 @@ def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression: if inner_with.recursive: top_level_with.set("recursive", True) - top_level_with.expressions.extend(inner_with.expressions) + top_level_with.set("expressions", inner_with.expressions + top_level_with.expressions) return expression @@ -464,7 +464,7 @@ def ensure_bools(expression: exp.Expression) -> exp.Expression: ): node.replace(node.neq(0)) - for node, *_ in expression.walk(): + for node in expression.walk(): ensure_bools(node, _ensure_bool) return expression @@ -561,9 +561,7 @@ def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Exp def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression: - """ - Convert struct arguments to aliases: STRUCT(1 AS y) . - """ + """Converts struct arguments to aliases, e.g. STRUCT(1 AS y).""" if isinstance(expression, exp.Struct): expression.set( "expressions", |