From 67c28dbe67209effad83d93b850caba5ee1e20e3 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 3 May 2023 11:12:28 +0200 Subject: Merging upstream version 11.7.1. Signed-off-by: Daniel Baumann --- sqlglot/__init__.py | 5 +- sqlglot/dataframe/sql/column.py | 2 +- sqlglot/dataframe/sql/dataframe.py | 151 ++++++++------ sqlglot/dataframe/sql/functions.py | 15 +- sqlglot/dataframe/sql/readwriter.py | 12 +- sqlglot/dialects/bigquery.py | 61 +++--- sqlglot/dialects/clickhouse.py | 7 + sqlglot/dialects/databricks.py | 8 +- sqlglot/dialects/dialect.py | 37 +++- sqlglot/dialects/drill.py | 4 + sqlglot/dialects/duckdb.py | 11 +- sqlglot/dialects/hive.py | 64 +++--- sqlglot/dialects/mysql.py | 63 +++--- sqlglot/dialects/oracle.py | 32 ++- sqlglot/dialects/postgres.py | 36 +++- sqlglot/dialects/presto.py | 85 +++++--- sqlglot/dialects/redshift.py | 17 ++ sqlglot/dialects/snowflake.py | 59 ++++-- sqlglot/dialects/spark.py | 54 ++++- sqlglot/dialects/sqlite.py | 10 +- sqlglot/dialects/starrocks.py | 8 +- sqlglot/dialects/tableau.py | 8 + sqlglot/dialects/teradata.py | 21 +- sqlglot/dialects/tsql.py | 37 ++++ sqlglot/expressions.py | 286 +++++++++++++++++++++++--- sqlglot/generator.py | 205 ++++++++++++++----- sqlglot/helper.py | 34 +++- sqlglot/lineage.py | 53 +++-- sqlglot/optimizer/annotate_types.py | 7 +- sqlglot/optimizer/normalize.py | 4 +- sqlglot/optimizer/qualify_columns.py | 5 +- sqlglot/optimizer/qualify_tables.py | 3 + sqlglot/optimizer/simplify.py | 35 ++-- sqlglot/parser.py | 382 ++++++++++++++++++++++++++--------- sqlglot/schema.py | 156 +++++++------- sqlglot/tokens.py | 163 ++++++++------- sqlglot/transforms.py | 135 +++++++++++-- sqlglot/trie.py | 5 +- 38 files changed, 1650 insertions(+), 630 deletions(-) (limited to 'sqlglot') diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 1feb464..42d89d1 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -21,10 +21,12 @@ from sqlglot.expressions import ( Expression as Expression, alias_ as alias, and_ as and_, + cast as cast, column as column, condition as condition, except_ as except_, from_ as from_, + func as func, intersect as intersect, maybe_parse as maybe_parse, not_ as not_, @@ -33,6 +35,7 @@ from sqlglot.expressions import ( subquery as subquery, table_ as table, to_column as to_column, + to_identifier as to_identifier, to_table as to_table, union as union, ) @@ -47,7 +50,7 @@ if t.TYPE_CHECKING: T = t.TypeVar("T", bound=Expression) -__version__ = "11.5.2" +__version__ = "11.7.1" pretty = False """Whether to format generated SQL by default.""" diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py index 609b2a4..a8b89d1 100644 --- a/sqlglot/dataframe/sql/column.py +++ b/sqlglot/dataframe/sql/column.py @@ -176,7 +176,7 @@ class Column: return isinstance(self.expression, exp.Column) @property - def column_expression(self) -> exp.Column: + def column_expression(self) -> t.Union[exp.Column, exp.Literal]: return self.expression.unalias() @property diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index 93bdf75..f3a6f6f 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -16,7 +16,7 @@ from sqlglot.dataframe.sql.readwriter import DataFrameWriter from sqlglot.dataframe.sql.transforms import replace_id_value from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join from sqlglot.dataframe.sql.window import Window -from sqlglot.helper import ensure_list, object_to_dict +from sqlglot.helper import 
ensure_list, object_to_dict, seq_get from sqlglot.optimizer import optimize as optimize_func if t.TYPE_CHECKING: @@ -146,9 +146,9 @@ class DataFrame: def _ensure_list_of_columns(self, cols): return Column.ensure_cols(ensure_list(cols)) - def _ensure_and_normalize_cols(self, cols): + def _ensure_and_normalize_cols(self, cols, expression: t.Optional[exp.Select] = None): cols = self._ensure_list_of_columns(cols) - normalize(self.spark, self.expression, cols) + normalize(self.spark, expression or self.expression, cols) return cols def _ensure_and_normalize_col(self, col): @@ -355,12 +355,20 @@ class DataFrame: cols = self._ensure_and_normalize_cols(cols) kwargs["append"] = kwargs.get("append", False) if self.expression.args.get("joins"): - ambiguous_cols = [col for col in cols if not col.column_expression.table] + ambiguous_cols = [ + col + for col in cols + if isinstance(col.column_expression, exp.Column) and not col.column_expression.table + ] if ambiguous_cols: join_table_identifiers = [ x.this for x in get_tables_from_expression_with_join(self.expression) ] cte_names_in_join = [x.this for x in join_table_identifiers] + # If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right + # and therefore we allow multiple columns with the same name in the result. This matches the behavior + # of Spark. + resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols} for ambiguous_col in ambiguous_cols: ctes_with_column = [ cte @@ -368,13 +376,14 @@ class DataFrame: if cte.alias_or_name in cte_names_in_join and ambiguous_col.alias_or_name in cte.this.named_selects ] - # If the select column does not specify a table and there is a join - # then we assume they are referring to the left table - if len(ctes_with_column) > 1: - table_identifier = self.expression.args["from"].args["expressions"][0].this + # Check if there is a CTE with this column that we haven't used before. If so, use it. Otherwise, + # use the same CTE we used before + cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1) + if cte: + resolved_column_position[ambiguous_col] += 1 else: - table_identifier = ctes_with_column[0].args["alias"].this - ambiguous_col.expression.set("table", table_identifier) + cte = ctes_with_column[resolved_column_position[ambiguous_col]] + ambiguous_col.expression.set("table", cte.alias_or_name) return self.copy( expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs ) @@ -416,59 +425,87 @@ class DataFrame: **kwargs, ) -> DataFrame: other_df = other_df._convert_leaf_to_cte() - pre_join_self_latest_cte_name = self.latest_cte_name - columns = self._ensure_and_normalize_cols(on) - join_type = how.replace("_", " ") - if isinstance(columns[0].expression, exp.Column): - join_columns = [ - Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns + join_columns = self._ensure_list_of_columns(on) + # We will determine actual "join on" expression later so we don't provide it at first + join_expression = self.expression.join( + other_df.latest_cte_name, join_type=how.replace("_", " ") + ) + join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes) + self_columns = self._get_outer_select_columns(join_expression) + other_columns = self._get_outer_select_columns(other_df) + # Determines the join clause and select columns to be used passed on what type of columns were provided for + # the join. 
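# A minimal standalone sketch of the left-to-right duplicate-column resolution
# described above: each repeated request for an ambiguous column name consumes
# the next CTE that exposes it, and once the candidates run out the last match
# is reused, mirroring Spark. The (cte_name, columns) pairs are hypothetical
# stand-ins for the real sqlglot structures.

def _resolve_left_to_right(requests, ctes):
    position = {name: -1 for name in requests}
    resolved = []
    for name in requests:
        candidates = [cte_name for cte_name, columns in ctes if name in columns]
        if position[name] + 1 < len(candidates):
            position[name] += 1
        resolved.append((candidates[position[name]], name))
    return resolved

# _resolve_left_to_right(["id", "id"], [("t1", {"id"}), ("t2", {"id"})])
# -> [("t1", "id"), ("t2", "id")]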
The columns returned changes based on how the on expression is provided. + if isinstance(join_columns[0].expression, exp.Column): + """ + Unique characteristics of join on column names only: + * The column names are put at the front of the select list + * The column names are deduplicated across the entire select list and only the column names (other dups are allowed) + """ + table_names = [ + table.alias_or_name + for table in get_tables_from_expression_with_join(join_expression) ] + potential_ctes = [ + cte + for cte in join_expression.ctes + if cte.alias_or_name in table_names + and cte.alias_or_name != other_df.latest_cte_name + ] + # Determine the table to reference for the left side of the join by checking each of the left side + # tables and see if they have the column being referenced. + join_column_pairs = [] + for join_column in join_columns: + num_matching_ctes = 0 + for cte in potential_ctes: + if join_column.alias_or_name in cte.this.named_selects: + left_column = join_column.copy().set_table_name(cte.alias_or_name) + right_column = join_column.copy().set_table_name(other_df.latest_cte_name) + join_column_pairs.append((left_column, right_column)) + num_matching_ctes += 1 + if num_matching_ctes > 1: + raise ValueError( + f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name." + ) + elif num_matching_ctes == 0: + raise ValueError( + f"Column {join_column.alias_or_name} does not exist in any of the tables." + ) join_clause = functools.reduce( lambda x, y: x & y, - [ - col.copy().set_table_name(pre_join_self_latest_cte_name) - == col.copy().set_table_name(other_df.latest_cte_name) - for col in columns - ], + [left_column == right_column for left_column, right_column in join_column_pairs], ) - else: - if len(columns) > 1: - columns = [functools.reduce(lambda x, y: x & y, columns)] - join_clause = columns[0] - join_columns = [ - Column(x).set_table_name(pre_join_self_latest_cte_name) - if i % 2 == 0 - else Column(x).set_table_name(other_df.latest_cte_name) - for i, x in enumerate(join_clause.expression.find_all(exp.Column)) + join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs] + # To match spark behavior only the join clause gets deduplicated and it gets put in the front of the column list + select_column_names = [ + column.alias_or_name + if not isinstance(column.expression.this, exp.Star) + else column.sql() + for column in self_columns + other_columns ] - self_columns = [ - column.set_table_name(pre_join_self_latest_cte_name, copy=True) - for column in self._get_outer_select_columns(self) - ] - other_columns = [ - column.set_table_name(other_df.latest_cte_name, copy=True) - for column in self._get_outer_select_columns(other_df) - ] - column_value_mapping = { - column.alias_or_name - if not isinstance(column.expression.this, exp.Star) - else column.sql(): column - for column in other_columns + self_columns + join_columns - } - all_columns = [ - column_value_mapping[name] - for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns} - ] - new_df = self.copy( - expression=self.expression.join( - other_df.latest_cte_name, on=join_clause.expression, join_type=join_type - ) - ) - new_df.expression = new_df._add_ctes_to_expression( - new_df.expression, other_df.expression.ctes - ) + select_column_names = [ + column_name + for column_name in select_column_names + if column_name not in join_column_names + ] + select_column_names = join_column_names + select_column_names + else: + """ + Unique 
characteristics of join on expressions: + * There is no deduplication of the results. + * The left join dataframe columns go first and right come after. No sort preference is given to join columns + """ + join_columns = self._ensure_and_normalize_cols(join_columns, join_expression) + if len(join_columns) > 1: + join_columns = [functools.reduce(lambda x, y: x & y, join_columns)] + join_clause = join_columns[0] + select_column_names = [column.alias_or_name for column in self_columns + other_columns] + + # Update the on expression with the actual join clause to replace the dummy one from before + join_expression.args["joins"][-1].set("on", join_clause.expression) + new_df = self.copy(expression=join_expression) + new_df.pending_join_hints.extend(self.pending_join_hints) new_df.pending_hints.extend(other_df.pending_hints) - new_df = new_df.select.__wrapped__(new_df, *all_columns) + new_df = new_df.select.__wrapped__(new_df, *select_column_names) return new_df @operation(Operation.ORDER_BY) diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index f77b4f8..993d869 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -577,11 +577,15 @@ def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Col def date_add(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column: - return Column.invoke_expression_over_column(col, expression.DateAdd, expression=days) + return Column.invoke_expression_over_column( + col, expression.DateAdd, expression=days, unit=expression.Var(this="day") + ) def date_sub(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column: - return Column.invoke_expression_over_column(col, expression.DateSub, expression=days) + return Column.invoke_expression_over_column( + col, expression.DateSub, expression=days, unit=expression.Var(this="day") + ) def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column: @@ -695,18 +699,17 @@ def crc32(col: ColumnOrName) -> Column: def md5(col: ColumnOrName) -> Column: column = col if isinstance(col, Column) else lit(col) - return Column.invoke_anonymous_function(column, "MD5") + return Column.invoke_expression_over_column(column, expression.MD5) def sha1(col: ColumnOrName) -> Column: column = col if isinstance(col, Column) else lit(col) - return Column.invoke_anonymous_function(column, "SHA1") + return Column.invoke_expression_over_column(column, expression.SHA) def sha2(col: ColumnOrName, numBits: int) -> Column: column = col if isinstance(col, Column) else lit(col) - num_bits = lit(numBits) - return Column.invoke_anonymous_function(column, "SHA2", num_bits) + return Column.invoke_expression_over_column(column, expression.SHA2, length=lit(numBits)) def hash(*cols: ColumnOrName) -> Column: diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py index febc664..cc2f181 100644 --- a/sqlglot/dataframe/sql/readwriter.py +++ b/sqlglot/dataframe/sql/readwriter.py @@ -4,7 +4,7 @@ import typing as t import sqlglot from sqlglot import expressions as exp -from sqlglot.helper import object_to_dict +from sqlglot.helper import object_to_dict, should_identify if t.TYPE_CHECKING: from sqlglot.dataframe.sql.dataframe import DataFrame @@ -19,9 +19,17 @@ class DataFrameReader: from sqlglot.dataframe.sql.dataframe import DataFrame sqlglot.schema.add_table(tableName) + return DataFrame( self.spark, - exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)), + exp.Select() + .from_(tableName) + .select( + *( + 
column if should_identify(column, "safe") else f'"{column}"' + for column in sqlglot.schema.column_names(tableName) + ) + ), ) diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 701377b..1a88654 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -13,6 +13,7 @@ from sqlglot.dialects.dialect import ( max_or_greatest, min_or_least, no_ilike_sql, + parse_date_delta_with_interval, rename_func, timestrtotime_sql, ts_or_ds_to_date_sql, @@ -23,18 +24,6 @@ from sqlglot.tokens import TokenType E = t.TypeVar("E", bound=exp.Expression) -def _date_add(expression_class: t.Type[E]) -> t.Callable[[t.Sequence], E]: - def func(args): - interval = seq_get(args, 1) - return expression_class( - this=seq_get(args, 0), - expression=interval.this, - unit=interval.args.get("unit"), - ) - - return func - - def _date_add_sql( data_type: str, kind: str ) -> t.Callable[[generator.Generator, exp.Expression], str]: @@ -142,6 +131,7 @@ class BigQuery(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "ANY TYPE": TokenType.VARIANT, "BEGIN": TokenType.COMMAND, "BEGIN TRANSACTION": TokenType.BEGIN, "CURRENT_DATETIME": TokenType.CURRENT_DATETIME, @@ -155,14 +145,19 @@ class BigQuery(Dialect): KEYWORDS.pop("DIV") class Parser(parser.Parser): + PREFIXED_PIVOT_COLUMNS = True + + LOG_BASE_FIRST = False + LOG_DEFAULTS_TO_LN = True + FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore "DATE_TRUNC": lambda args: exp.DateTrunc( unit=exp.Literal.string(seq_get(args, 1).name), # type: ignore this=seq_get(args, 0), ), - "DATE_ADD": _date_add(exp.DateAdd), - "DATETIME_ADD": _date_add(exp.DatetimeAdd), + "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd), + "DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd), "DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)), "REGEXP_CONTAINS": exp.RegexpLike.from_arg_list, "REGEXP_EXTRACT": lambda args: exp.RegexpExtract( @@ -174,12 +169,12 @@ class BigQuery(Dialect): if re.compile(str(seq_get(args, 1))).groups == 1 else None, ), - "TIME_ADD": _date_add(exp.TimeAdd), - "TIMESTAMP_ADD": _date_add(exp.TimestampAdd), - "DATE_SUB": _date_add(exp.DateSub), - "DATETIME_SUB": _date_add(exp.DatetimeSub), - "TIME_SUB": _date_add(exp.TimeSub), - "TIMESTAMP_SUB": _date_add(exp.TimestampSub), + "TIME_ADD": parse_date_delta_with_interval(exp.TimeAdd), + "TIMESTAMP_ADD": parse_date_delta_with_interval(exp.TimestampAdd), + "DATE_SUB": parse_date_delta_with_interval(exp.DateSub), + "DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub), + "TIME_SUB": parse_date_delta_with_interval(exp.TimeSub), + "TIMESTAMP_SUB": parse_date_delta_with_interval(exp.TimestampSub), "PARSE_TIMESTAMP": lambda args: exp.StrToTime( this=seq_get(args, 1), format=seq_get(args, 0) ), @@ -209,14 +204,17 @@ class BigQuery(Dialect): PROPERTY_PARSERS = { **parser.Parser.PROPERTY_PARSERS, # type: ignore "NOT DETERMINISTIC": lambda self: self.expression( - exp.VolatilityProperty, this=exp.Literal.string("VOLATILE") + exp.StabilityProperty, this=exp.Literal.string("VOLATILE") ), } - LOG_BASE_FIRST = False - LOG_DEFAULTS_TO_LN = True - class Generator(generator.Generator): + EXPLICIT_UNION = True + INTERVAL_ALLOWS_PLURAL_FORM = False + JOIN_HINTS = False + TABLE_HINTS = False + LIMIT_FETCH = "LIMIT" + TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore **transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES, # type: ignore @@ -236,9 +234,7 @@ class BigQuery(Dialect): exp.IntDiv: rename_func("DIV"), exp.Max: max_or_greatest, 
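# A rough sketch of the round trip that parse_date_delta_with_interval and
# _date_add_sql enable: BigQuery's INTERVAL-based date arithmetic is parsed
# into exp.DateAdd with an explicit unit and re-rendered per target dialect.
# The output shown is the expected result as of this version and may differ.

import sqlglot

sqlglot.transpile("SELECT DATE_ADD(x, INTERVAL 1 DAY)", read="bigquery", write="hive")
# expected -> ['SELECT DATE_ADD(x, 1)']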
exp.Min: min_or_least, - exp.Select: transforms.preprocess( - [_unqualify_unnest], transforms.delegate("select_sql") - ), + exp.Select: transforms.preprocess([_unqualify_unnest]), exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})", exp.TimeAdd: _date_add_sql("TIME", "ADD"), exp.TimeSub: _date_add_sql("TIME", "SUB"), @@ -253,7 +249,7 @@ class BigQuery(Dialect): exp.ReturnsProperty: _returnsproperty_sql, exp.Create: _create_sql, exp.Trim: lambda self, e: self.func(f"TRIM", e.this, e.expression), - exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC" + exp.StabilityProperty: lambda self, e: f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC", exp.RegexpLike: rename_func("REGEXP_CONTAINS"), @@ -261,6 +257,7 @@ class BigQuery(Dialect): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore + exp.DataType.Type.BIGDECIMAL: "BIGNUMERIC", exp.DataType.Type.BIGINT: "INT64", exp.DataType.Type.BOOLEAN: "BOOL", exp.DataType.Type.CHAR: "STRING", @@ -272,17 +269,19 @@ class BigQuery(Dialect): exp.DataType.Type.NVARCHAR: "STRING", exp.DataType.Type.SMALLINT: "INT64", exp.DataType.Type.TEXT: "STRING", + exp.DataType.Type.TIMESTAMP: "DATETIME", + exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", exp.DataType.Type.TINYINT: "INT64", exp.DataType.Type.VARCHAR: "STRING", + exp.DataType.Type.VARIANT: "ANY TYPE", } + PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, # type: ignore exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } - EXPLICIT_UNION = True - LIMIT_FETCH = "LIMIT" - def array_sql(self, expression: exp.Array) -> str: first_arg = seq_get(expression.expressions, 0) if isinstance(first_arg, exp.Subqueryable): diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index b06462c..e91b0bf 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -144,6 +144,13 @@ class ClickHouse(Dialect): exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)), } + PROPERTIES_LOCATION = { + **generator.Generator.PROPERTIES_LOCATION, # type: ignore + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, + } + + JOIN_HINTS = False + TABLE_HINTS = False EXPLICIT_UNION = True def _param_args_sql( diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 2f93ee7..138f26c 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -9,6 +9,8 @@ from sqlglot.tokens import TokenType class Databricks(Spark): class Parser(Spark.Parser): + LOG_DEFAULTS_TO_LN = True + FUNCTIONS = { **Spark.Parser.FUNCTIONS, "DATEADD": parse_date_delta(exp.DateAdd), @@ -16,13 +18,17 @@ class Databricks(Spark): "DATEDIFF": parse_date_delta(exp.DateDiff), } - LOG_DEFAULTS_TO_LN = True + FACTOR = { + **Spark.Parser.FACTOR, + TokenType.COLON: exp.JSONExtract, + } class Generator(Spark.Generator): TRANSFORMS = { **Spark.Generator.TRANSFORMS, # type: ignore exp.DateAdd: generate_date_delta_with_unit_sql, exp.DateDiff: generate_date_delta_with_unit_sql, + exp.JSONExtract: lambda self, e: self.binary(e, ":"), exp.ToChar: lambda self, e: self.function_fallback_sql(e), } TRANSFORMS.pop(exp.Select) # Remove the ELIMINATE_QUALIFY transformation diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 839589d..19c6f73 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -293,6 +293,13 @@ def no_properties_sql(self: Generator, expression: exp.Properties) 
-> str: return "" +def no_comment_column_constraint_sql( + self: Generator, expression: exp.CommentColumnConstraint +) -> str: + self.unsupported("CommentColumnConstraint unsupported") + return "" + + def str_position_sql(self: Generator, expression: exp.StrPosition) -> str: this = self.sql(expression, "this") substr = self.sql(expression, "substr") @@ -379,15 +386,35 @@ def parse_date_delta( ) -> t.Callable[[t.Sequence], E]: def inner_func(args: t.Sequence) -> E: unit_based = len(args) == 3 - this = seq_get(args, 2) if unit_based else seq_get(args, 0) - expression = seq_get(args, 1) if unit_based else seq_get(args, 1) - unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY") - unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit # type: ignore - return exp_class(this=this, expression=expression, unit=unit) + this = args[2] if unit_based else seq_get(args, 0) + unit = args[0] if unit_based else exp.Literal.string("DAY") + unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit + return exp_class(this=this, expression=seq_get(args, 1), unit=unit) return inner_func +def parse_date_delta_with_interval( + expression_class: t.Type[E], +) -> t.Callable[[t.Sequence], t.Optional[E]]: + def func(args: t.Sequence) -> t.Optional[E]: + if len(args) < 2: + return None + + interval = args[1] + expression = interval.this + if expression and expression.is_string: + expression = exp.Literal.number(expression.this) + + return expression_class( + this=args[0], + expression=expression, + unit=exp.Literal.string(interval.text("unit")), + ) + + return func + + def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc: unit = seq_get(args, 0) this = seq_get(args, 1) diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index a33aadc..d7e2d88 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -104,6 +104,9 @@ class Drill(Dialect): LOG_DEFAULTS_TO_LN = True class Generator(generator.Generator): + JOIN_HINTS = False + TABLE_HINTS = False + TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.INT: "INTEGER", @@ -120,6 +123,7 @@ class Drill(Dialect): PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, # type: ignore exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } TRANSFORMS = { diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index c034208..9454db6 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import ( arrow_json_extract_sql, datestrtodate_sql, format_time_lambda, + no_comment_column_constraint_sql, no_pivot_sql, no_properties_sql, no_safe_divide_sql, @@ -23,7 +24,7 @@ from sqlglot.tokens import TokenType def _ts_or_ds_add(self, expression): - this = expression.args.get("this") + this = self.sql(expression, "this") unit = self.sql(expression, "unit").strip("'") or "DAY" return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}" @@ -139,6 +140,8 @@ class DuckDB(Dialect): } class Generator(generator.Generator): + JOIN_HINTS = False + TABLE_HINTS = False STRUCT_DELIMITER = ("(", ")") TRANSFORMS = { @@ -150,6 +153,7 @@ class DuckDB(Dialect): exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.ArraySort: _array_sort_sql, exp.ArraySum: rename_func("LIST_SUM"), + exp.CommentColumnConstraint: no_comment_column_constraint_sql, exp.DayOfMonth: 
rename_func("DAYOFMONTH"), exp.DayOfWeek: rename_func("DAYOFWEEK"), exp.DayOfYear: rename_func("DAYOFYEAR"), @@ -213,6 +217,11 @@ class DuckDB(Dialect): "except": "EXCLUDE", } + PROPERTIES_LOCATION = { + **generator.Generator.PROPERTIES_LOCATION, # type: ignore + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, + } + LIMIT_FETCH = "LIMIT" def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str: diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index c39656e..6746fcf 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -45,16 +45,23 @@ TIME_DIFF_FACTOR = { DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH") -def _add_date_sql(self: generator.Generator, expression: exp.DateAdd) -> str: +def _add_date_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: unit = expression.text("unit").upper() func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1)) - modified_increment = ( - int(expression.text("expression")) * multiplier - if expression.expression.is_number - else expression.expression - ) - modified_increment = exp.Literal.number(modified_increment) - return self.func(func, expression.this, modified_increment.this) + + if isinstance(expression, exp.DateSub): + multiplier *= -1 + + if expression.expression.is_number: + modified_increment = exp.Literal.number(int(expression.text("expression")) * multiplier) + else: + modified_increment = expression.expression + if multiplier != 1: + modified_increment = exp.Mul( # type: ignore + this=modified_increment, expression=exp.Literal.number(multiplier) + ) + + return self.func(func, expression.this, modified_increment) def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str: @@ -127,24 +134,6 @@ def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str return f"TO_DATE({this})" -def _unnest_to_explode_sql(self: generator.Generator, expression: exp.Join) -> str: - unnest = expression.this - if isinstance(unnest, exp.Unnest): - alias = unnest.args.get("alias") - udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode - return "".join( - self.sql( - exp.Lateral( - this=udtf(this=expression), - view=True, - alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore - ) - ) - for expression, column in zip(unnest.expressions, alias.columns if alias else []) - ) - return self.join_sql(expression) - - def _index_sql(self: generator.Generator, expression: exp.Index) -> str: this = self.sql(expression, "this") table = self.sql(expression, "table") @@ -195,6 +184,7 @@ class Hive(Dialect): IDENTIFIERS = ["`"] STRING_ESCAPES = ["\\"] ENCODE = "utf-8" + IDENTIFIER_CAN_START_WITH_DIGIT = True KEYWORDS = { **tokens.Tokenizer.KEYWORDS, @@ -217,9 +207,8 @@ class Hive(Dialect): "BD": "DECIMAL", } - IDENTIFIER_CAN_START_WITH_DIGIT = True - class Parser(parser.Parser): + LOG_DEFAULTS_TO_LN = True STRICT_CAST = False FUNCTIONS = { @@ -273,9 +262,13 @@ class Hive(Dialect): ), } - LOG_DEFAULTS_TO_LN = True - class Generator(generator.Generator): + LIMIT_FETCH = "LIMIT" + TABLESAMPLE_WITH_METHOD = False + TABLESAMPLE_SIZE_IS_PERCENT = True + JOIN_HINTS = False + TABLE_HINTS = False + TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.TEXT: "STRING", @@ -289,6 +282,9 @@ class Hive(Dialect): **generator.Generator.TRANSFORMS, # type: ignore **transforms.UNALIAS_GROUP, # type: ignore **transforms.ELIMINATE_QUALIFY, # type: ignore + exp.Select: transforms.preprocess( 
+ [transforms.eliminate_qualify, transforms.unnest_to_explode] + ), exp.Property: _property_sql, exp.ApproxDistinct: approx_count_distinct_sql, exp.ArrayConcat: rename_func("CONCAT"), @@ -298,13 +294,13 @@ class Hive(Dialect): exp.DateAdd: _add_date_sql, exp.DateDiff: _date_diff_sql, exp.DateStrToDate: rename_func("TO_DATE"), + exp.DateSub: _add_date_sql, exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)", exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})", - exp.FileFormatProperty: lambda self, e: f"STORED AS {e.name.upper()}", + exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}", exp.If: if_sql, exp.Index: _index_sql, exp.ILike: no_ilike_sql, - exp.Join: _unnest_to_explode_sql, exp.JSONExtract: rename_func("GET_JSON_OBJECT"), exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"), exp.JSONFormat: rename_func("TO_JSON"), @@ -354,10 +350,9 @@ class Hive(Dialect): exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA, exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA, + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } - LIMIT_FETCH = "LIMIT" - def arrayagg_sql(self, expression: exp.ArrayAgg) -> str: return self.func( "COLLECT_LIST", @@ -378,4 +373,5 @@ class Hive(Dialect): expression = exp.DataType.build("text") elif expression.this in exp.DataType.TEMPORAL_TYPES: expression = exp.DataType.build(expression.this) + return super().datatype_sql(expression) diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index d64efbf..666e740 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -4,6 +4,8 @@ from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, arrow_json_extract_scalar_sql, + datestrtodate_sql, + format_time_lambda, locate_to_strposition, max_or_greatest, min_or_least, @@ -11,6 +13,7 @@ from sqlglot.dialects.dialect import ( no_paren_current_date_sql, no_tablesample_sql, no_trycast_sql, + parse_date_delta_with_interval, rename_func, strposition_to_locate_sql, ) @@ -76,18 +79,6 @@ def _trim_sql(self, expression): return f"TRIM({trim_type}{remove_chars}{from_part}{target})" -def _date_add(expression_class): - def func(args): - interval = seq_get(args, 1) - return expression_class( - this=seq_get(args, 0), - expression=interval.this, - unit=exp.Literal.string(interval.text("unit").lower()), - ) - - return func - - def _date_add_sql(kind): def func(self, expression): this = self.sql(expression, "this") @@ -115,6 +106,7 @@ class MySQL(Dialect): "%k": "%-H", "%l": "%-I", "%T": "%H:%M:%S", + "%W": "%a", } class Tokenizer(tokens.Tokenizer): @@ -127,12 +119,13 @@ class MySQL(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, - "MEDIUMTEXT": TokenType.MEDIUMTEXT, + "CHARSET": TokenType.CHARACTER_SET, + "LONGBLOB": TokenType.LONGBLOB, "LONGTEXT": TokenType.LONGTEXT, "MEDIUMBLOB": TokenType.MEDIUMBLOB, - "LONGBLOB": TokenType.LONGBLOB, - "START": TokenType.BEGIN, + "MEDIUMTEXT": TokenType.MEDIUMTEXT, "SEPARATOR": TokenType.SEPARATOR, + "START": TokenType.BEGIN, "_ARMSCII8": TokenType.INTRODUCER, "_ASCII": TokenType.INTRODUCER, "_BIG5": TokenType.INTRODUCER, @@ -186,14 +179,15 @@ class MySQL(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore - "DATE_ADD": _date_add(exp.DateAdd), - "DATE_SUB": _date_add(exp.DateSub), - 
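# Sketch of how these FUNCTIONS entries are consumed: the parser looks up the
# function name and calls the mapped builder with the parsed argument list.
# This is a hypothetical standalone dispatch, not the parser's real code path.

from sqlglot import exp
from sqlglot.helper import seq_get

FUNCS = {"INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0))}
node = FUNCS["INSTR"]([exp.column("haystack"), exp.Literal.string("needle")])
# node.sql(dialect="mysql") is expected to render as "LOCATE('needle', haystack)"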
"STR_TO_DATE": _str_to_date, - "LOCATE": locate_to_strposition, + "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd), + "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"), + "DATE_SUB": parse_date_delta_with_interval(exp.DateSub), "INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)), "LEFT": lambda args: exp.Substring( this=seq_get(args, 0), start=exp.Literal.number(1), length=seq_get(args, 1) ), + "LOCATE": locate_to_strposition, + "STR_TO_DATE": _str_to_date, } FUNCTION_PARSERS = { @@ -388,32 +382,36 @@ class MySQL(Dialect): class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True NULL_ORDERING_SUPPORTED = False + JOIN_HINTS = False + TABLE_HINTS = False TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore exp.CurrentDate: no_paren_current_date_sql, - exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", - exp.ILike: no_ilike_sql, - exp.JSONExtractScalar: arrow_json_extract_scalar_sql, - exp.Max: max_or_greatest, - exp.Min: min_or_least, - exp.TableSample: no_tablesample_sql, - exp.TryCast: no_trycast_sql, + exp.DateDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression), exp.DateAdd: _date_add_sql("ADD"), - exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})", + exp.DateStrToDate: datestrtodate_sql, exp.DateSub: _date_add_sql("SUB"), exp.DateTrunc: _date_trunc_sql, - exp.DayOfWeek: rename_func("DAYOFWEEK"), exp.DayOfMonth: rename_func("DAYOFMONTH"), + exp.DayOfWeek: rename_func("DAYOFWEEK"), exp.DayOfYear: rename_func("DAYOFYEAR"), - exp.WeekOfYear: rename_func("WEEKOFYEAR"), exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""", - exp.StrToDate: _str_to_date_sql, - exp.StrToTime: _str_to_date_sql, - exp.Trim: _trim_sql, + exp.ILike: no_ilike_sql, + exp.JSONExtractScalar: arrow_json_extract_scalar_sql, + exp.Max: max_or_greatest, + exp.Min: min_or_least, exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"), exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")), exp.StrPosition: strposition_to_locate_sql, + exp.StrToDate: _str_to_date_sql, + exp.StrToTime: _str_to_date_sql, + exp.TableSample: no_tablesample_sql, + exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), + exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)), + exp.Trim: _trim_sql, + exp.TryCast: no_trycast_sql, + exp.WeekOfYear: rename_func("WEEKOFYEAR"), } TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy() @@ -425,6 +423,7 @@ class MySQL(Dialect): PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, # type: ignore exp.TransientProperty: exp.Properties.Location.UNSUPPORTED, + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } LIMIT_FETCH = "LIMIT" diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 3819b76..9ccd02e 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -7,11 +7,6 @@ from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sq from sqlglot.helper import seq_get from sqlglot.tokens import TokenType -PASSING_TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - { - TokenType.COLUMN, - TokenType.RETURNING, -} - def _parse_xml_table(self) -> exp.XMLTable: this = self._parse_string() @@ -22,9 +17,7 @@ def _parse_xml_table(self) -> exp.XMLTable: if self._match_text_seq("PASSING"): # The BY VALUE keywords are optional and are provided for semantic clarity self._match_text_seq("BY", "VALUE") - passing = 
self._parse_csv( - lambda: self._parse_table(alias_tokens=PASSING_TABLE_ALIAS_TOKENS) - ) + passing = self._parse_csv(self._parse_column) by_ref = self._match_text_seq("RETURNING", "SEQUENCE", "BY", "REF") @@ -68,6 +61,8 @@ class Oracle(Dialect): } class Parser(parser.Parser): + WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER, TokenType.KEEP} + FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), @@ -78,6 +73,12 @@ class Oracle(Dialect): "XMLTABLE": _parse_xml_table, } + TYPE_LITERAL_PARSERS = { + exp.DataType.Type.DATE: lambda self, this, _: self.expression( + exp.DateStrToDate, this=this + ) + } + def _parse_column(self) -> t.Optional[exp.Expression]: column = super()._parse_column() if column: @@ -100,6 +101,8 @@ class Oracle(Dialect): class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True + JOIN_HINTS = False + TABLE_HINTS = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore @@ -119,6 +122,9 @@ class Oracle(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore **transforms.UNALIAS_GROUP, # type: ignore + exp.DateStrToDate: lambda self, e: self.func( + "TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD") + ), exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */", exp.ILike: no_ilike_sql, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", @@ -129,6 +135,12 @@ class Oracle(Dialect): exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.Trim: trim_sql, exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)", + exp.IfNull: rename_func("NVL"), + } + + PROPERTIES_LOCATION = { + **generator.Generator.PROPERTIES_LOCATION, # type: ignore + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } LIMIT_FETCH = "FETCH" @@ -142,9 +154,9 @@ class Oracle(Dialect): def xmltable_sql(self, expression: exp.XMLTable) -> str: this = self.sql(expression, "this") - passing = self.expressions(expression, "passing") + passing = self.expressions(expression, key="passing") passing = f"{self.sep()}PASSING{self.seg(passing)}" if passing else "" - columns = self.expressions(expression, "columns") + columns = self.expressions(expression, key="columns") columns = f"{self.sep()}COLUMNS{self.seg(columns)}" if columns else "" by_ref = ( f"{self.sep()}RETURNING SEQUENCE BY REF" if expression.args.get("by_ref") else "" diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 31b7e45..c47ff51 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -5,6 +5,7 @@ from sqlglot.dialects.dialect import ( Dialect, arrow_json_extract_scalar_sql, arrow_json_extract_sql, + datestrtodate_sql, format_time_lambda, max_or_greatest, min_or_least, @@ -19,7 +20,7 @@ from sqlglot.dialects.dialect import ( from sqlglot.helper import seq_get from sqlglot.parser import binary_range_parser from sqlglot.tokens import TokenType -from sqlglot.transforms import delegate, preprocess +from sqlglot.transforms import preprocess, remove_target_from_merge DATE_DIFF_FACTOR = { "MICROSECOND": " * 1000000", @@ -239,7 +240,6 @@ class Postgres(Dialect): "SERIAL": TokenType.SERIAL, "SMALLSERIAL": TokenType.SMALLSERIAL, "TEMP": TokenType.TEMPORARY, - "UUID": TokenType.UUID, "CSTRING": TokenType.PSEUDO_TYPE, } @@ -248,18 +248,25 @@ class Postgres(Dialect): "$": TokenType.PARAMETER, } + VAR_SINGLE_TOKENS = {"$"} + class Parser(parser.Parser): STRICT_CAST = 
False FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore - "NOW": exp.CurrentTimestamp.from_arg_list, - "TO_TIMESTAMP": _to_timestamp, - "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"), - "GENERATE_SERIES": _generate_series, "DATE_TRUNC": lambda args: exp.TimestampTrunc( this=seq_get(args, 1), unit=seq_get(args, 0) ), + "GENERATE_SERIES": _generate_series, + "NOW": exp.CurrentTimestamp.from_arg_list, + "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"), + "TO_TIMESTAMP": _to_timestamp, + } + + FUNCTION_PARSERS = { + **parser.Parser.FUNCTION_PARSERS, + "DATE_PART": lambda self: self._parse_date_part(), } BITWISE = { @@ -279,8 +286,21 @@ class Postgres(Dialect): TokenType.LT_AT: binary_range_parser(exp.ArrayContained), } + def _parse_date_part(self) -> exp.Expression: + part = self._parse_type() + self._match(TokenType.COMMA) + value = self._parse_bitwise() + + if part and part.is_string: + part = exp.Var(this=part.name) + + return self.expression(exp.Extract, this=part, expression=value) + class Generator(generator.Generator): + INTERVAL_ALLOWS_PLURAL_FORM = False LOCKING_READS_SUPPORTED = True + JOIN_HINTS = False + TABLE_HINTS = False PARAMETER_TOKEN = "$" TYPE_MAPPING = { @@ -301,7 +321,6 @@ class Postgres(Dialect): _auto_increment_to_serial, _serial_to_generated, ], - delegate("columndef_sql"), ), exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, @@ -312,6 +331,7 @@ class Postgres(Dialect): exp.CurrentDate: no_paren_current_date_sql, exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.DateAdd: _date_add_sql("+"), + exp.DateStrToDate: datestrtodate_sql, exp.DateSub: _date_add_sql("-"), exp.DateDiff: _date_diff_sql, exp.LogicalOr: rename_func("BOOL_OR"), @@ -321,6 +341,7 @@ class Postgres(Dialect): exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"), exp.ArrayContains: lambda self, e: self.binary(e, "@>"), exp.ArrayContained: lambda self, e: self.binary(e, "<@"), + exp.Merge: preprocess([remove_target_from_merge]), exp.RegexpLike: lambda self, e: self.binary(e, "~"), exp.RegexpILike: lambda self, e: self.binary(e, "~*"), exp.StrPosition: str_position_sql, @@ -344,4 +365,5 @@ class Postgres(Dialect): PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, # type: ignore exp.TransientProperty: exp.Properties.Location.UNSUPPORTED, + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 07e8f43..489d439 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, @@ -19,20 +21,20 @@ from sqlglot.helper import seq_get from sqlglot.tokens import TokenType -def _approx_distinct_sql(self, expression): +def _approx_distinct_sql(self: generator.Generator, expression: exp.ApproxDistinct) -> str: accuracy = expression.args.get("accuracy") accuracy = ", " + self.sql(accuracy) if accuracy else "" return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})" -def _datatype_sql(self, expression): +def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: sql = self.datatype_sql(expression) if expression.this == exp.DataType.Type.TIMESTAMPTZ: sql = f"{sql} WITH TIME ZONE" return sql -def _explode_to_unnest_sql(self, expression): +def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str: if 
isinstance(expression.this, (exp.Explode, exp.Posexplode)): return self.sql( exp.Join( @@ -47,22 +49,22 @@ def _explode_to_unnest_sql(self, expression): return self.lateral_sql(expression) -def _initcap_sql(self, expression): +def _initcap_sql(self: generator.Generator, expression: exp.Initcap) -> str: regex = r"(\w)(\w*)" return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))" -def _decode_sql(self, expression): - _ensure_utf8(expression.args.get("charset")) +def _decode_sql(self: generator.Generator, expression: exp.Decode) -> str: + _ensure_utf8(expression.args["charset"]) return self.func("FROM_UTF8", expression.this, expression.args.get("replace")) -def _encode_sql(self, expression): - _ensure_utf8(expression.args.get("charset")) +def _encode_sql(self: generator.Generator, expression: exp.Encode) -> str: + _ensure_utf8(expression.args["charset"]) return f"TO_UTF8({self.sql(expression, 'this')})" -def _no_sort_array(self, expression): +def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str: if expression.args.get("asc") == exp.false(): comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END" else: @@ -70,49 +72,62 @@ def _no_sort_array(self, expression): return self.func("ARRAY_SORT", expression.this, comparator) -def _schema_sql(self, expression): +def _schema_sql(self: generator.Generator, expression: exp.Schema) -> str: if isinstance(expression.parent, exp.Property): columns = ", ".join(f"'{c.name}'" for c in expression.expressions) return f"ARRAY[{columns}]" - for schema in expression.parent.find_all(exp.Schema): - if isinstance(schema.parent, exp.Property): - expression = expression.copy() - expression.expressions.extend(schema.expressions) + if expression.parent: + for schema in expression.parent.find_all(exp.Schema): + if isinstance(schema.parent, exp.Property): + expression = expression.copy() + expression.expressions.extend(schema.expressions) return self.schema_sql(expression) -def _quantile_sql(self, expression): +def _quantile_sql(self: generator.Generator, expression: exp.Quantile) -> str: self.unsupported("Presto does not support exact quantiles") return f"APPROX_PERCENTILE({self.sql(expression, 'this')}, {self.sql(expression, 'quantile')})" -def _str_to_time_sql(self, expression): +def _str_to_time_sql( + self: generator.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate +) -> str: return f"DATE_PARSE({self.sql(expression, 'this')}, {self.format_time(expression)})" -def _ts_or_ds_to_date_sql(self, expression): +def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str: time_format = self.format_time(expression) if time_format and time_format not in (Presto.time_format, Presto.date_format): return f"CAST({_str_to_time_sql(self, expression)} AS DATE)" return f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)" -def _ts_or_ds_add_sql(self, expression): +def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str: + this = expression.this + + if not isinstance(this, exp.CurrentDate): + this = self.func( + "DATE_PARSE", + self.func( + "SUBSTR", + this if this.is_string else exp.cast(this, "VARCHAR"), + exp.Literal.number(1), + exp.Literal.number(10), + ), + Presto.date_format, + ) + return self.func( "DATE_ADD", exp.Literal.string(expression.text("unit") or "day"), expression.expression, - self.func( - "DATE_PARSE", - self.func("SUBSTR", expression.this, exp.Literal.number(1), 
exp.Literal.number(10)), - Presto.date_format, - ), + this, ) -def _sequence_sql(self, expression): +def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str: start = expression.args["start"] end = expression.args["end"] step = expression.args.get("step", 1) # Postgres defaults to 1 for generate_series @@ -135,12 +150,12 @@ def _sequence_sql(self, expression): return self.func("SEQUENCE", start, end, step) -def _ensure_utf8(charset): +def _ensure_utf8(charset: exp.Literal) -> None: if charset.name.lower() != "utf-8": raise UnsupportedError(f"Unsupported charset {charset}") -def _approx_percentile(args): +def _approx_percentile(args: t.Sequence) -> exp.Expression: if len(args) == 4: return exp.ApproxQuantile( this=seq_get(args, 0), @@ -157,7 +172,7 @@ def _approx_percentile(args): return exp.ApproxQuantile.from_arg_list(args) -def _from_unixtime(args): +def _from_unixtime(args: t.Sequence) -> exp.Expression: if len(args) == 3: return exp.UnixToTime( this=seq_get(args, 0), @@ -226,11 +241,15 @@ class Presto(Dialect): FUNCTION_PARSERS.pop("TRIM") class Generator(generator.Generator): + INTERVAL_ALLOWS_PLURAL_FORM = False + JOIN_HINTS = False + TABLE_HINTS = False STRUCT_DELIMITER = ("(", ")") PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, # type: ignore exp.LocationProperty: exp.Properties.Location.UNSUPPORTED, + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } TYPE_MAPPING = { @@ -246,7 +265,6 @@ class Presto(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore **transforms.UNALIAS_GROUP, # type: ignore - **transforms.ELIMINATE_QUALIFY, # type: ignore exp.ApproxDistinct: _approx_distinct_sql, exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", exp.ArrayConcat: rename_func("CONCAT"), @@ -284,6 +302,9 @@ class Presto(Dialect): exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"), exp.SafeDivide: no_safe_divide_sql, exp.Schema: _schema_sql, + exp.Select: transforms.preprocess( + [transforms.eliminate_qualify, transforms.explode_to_unnest] + ), exp.SortArray: _no_sort_array, exp.StrPosition: rename_func("STRPOS"), exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)", @@ -308,7 +329,13 @@ class Presto(Dialect): exp.VariancePop: rename_func("VAR_POP"), } - def transaction_sql(self, expression): + def interval_sql(self, expression: exp.Interval) -> str: + unit = self.sql(expression, "unit") + if expression.this and unit.lower().startswith("week"): + return f"({expression.this.name} * INTERVAL '7' day)" + return super().interval_sql(expression) + + def transaction_sql(self, expression: exp.Transaction) -> str: modes = expression.args.get("modes") modes = f" {', '.join(modes)}" if modes else "" return f"START TRANSACTION{modes}" diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 63c14f4..a9c4f62 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -8,6 +8,10 @@ from sqlglot.helper import seq_get from sqlglot.tokens import TokenType +def _json_sql(self, e) -> str: + return f'{self.sql(e, "this")}."{e.expression.name}"' + + class Redshift(Postgres): time_format = "'YYYY-MM-DD HH:MI:SS'" time_mapping = { @@ -56,6 +60,7 @@ class Redshift(Postgres): "GEOGRAPHY": TokenType.GEOGRAPHY, "HLLSKETCH": TokenType.HLLSKETCH, "SUPER": TokenType.SUPER, + "SYSDATE": TokenType.CURRENT_TIMESTAMP, "TIME": TokenType.TIMESTAMP, "TIMETZ": TokenType.TIMESTAMPTZ, "TOP": TokenType.TOP, @@ -63,7 +68,14 @@ class Redshift(Postgres): "VARBYTE": 
TokenType.VARBINARY, } + # Redshift allows # to appear as a table identifier prefix + SINGLE_TOKENS = Postgres.Tokenizer.SINGLE_TOKENS.copy() + SINGLE_TOKENS.pop("#") + class Generator(Postgres.Generator): + LOCKING_READS_SUPPORTED = False + SINGLE_STRING_INTERVAL = True + TYPE_MAPPING = { **Postgres.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.BINARY: "VARBYTE", @@ -79,6 +91,7 @@ class Redshift(Postgres): TRANSFORMS = { **Postgres.Generator.TRANSFORMS, # type: ignore **transforms.ELIMINATE_DISTINCT_ON, # type: ignore + exp.CurrentTimestamp: lambda self, e: "SYSDATE", exp.DateAdd: lambda self, e: self.func( "DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this ), @@ -87,12 +100,16 @@ class Redshift(Postgres): ), exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})", exp.DistStyleProperty: lambda self, e: self.naked_property(e), + exp.JSONExtract: _json_sql, + exp.JSONExtractScalar: _json_sql, exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", } # Redshift uses the POW | POWER (expr1, expr2) syntax instead of expr1 ^ expr2 (postgres) TRANSFORMS.pop(exp.Pow) + RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot"} + def values_sql(self, expression: exp.Values) -> str: """ Converts `VALUES...` expression into a series of unions. diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 34bc3bd..0829669 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -23,14 +23,14 @@ from sqlglot.parser import binary_range_parser from sqlglot.tokens import TokenType -def _check_int(s): +def _check_int(s: str) -> bool: if s[0] in ("-", "+"): return s[1:].isdigit() return s.isdigit() # from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html -def _snowflake_to_timestamp(args): +def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.UnixToTime]: if len(args) == 2: first_arg, second_arg = args if second_arg.is_string: @@ -69,7 +69,7 @@ def _snowflake_to_timestamp(args): return exp.UnixToTime.from_arg_list(args) -def _unix_to_time_sql(self, expression): +def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> str: scale = expression.args.get("scale") timestamp = self.sql(expression, "this") if scale in [None, exp.UnixToTime.SECONDS]: @@ -84,8 +84,12 @@ def _unix_to_time_sql(self, expression): # https://docs.snowflake.com/en/sql-reference/functions/date_part.html # https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts -def _parse_date_part(self): +def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]: this = self._parse_var() or self._parse_type() + + if not this: + return None + self._match(TokenType.COMMA) expression = self._parse_bitwise() @@ -101,7 +105,7 @@ def _parse_date_part(self): scale = None ts = self.expression(exp.Cast, this=expression, to=exp.DataType.build("TIMESTAMP")) - to_unix = self.expression(exp.TimeToUnix, this=ts) + to_unix: exp.Expression = self.expression(exp.TimeToUnix, this=ts) if scale: to_unix = exp.Mul(this=to_unix, expression=exp.Literal.number(scale)) @@ -112,7 +116,7 @@ def _parse_date_part(self): # https://docs.snowflake.com/en/sql-reference/functions/div0 -def _div0_to_if(args): +def _div0_to_if(args: t.Sequence) -> exp.Expression: cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0)) true = exp.Literal.number(0) false = exp.Div(this=seq_get(args, 0), 
expression=seq_get(args, 1)) @@ -120,18 +124,18 @@ def _div0_to_if(args): # https://docs.snowflake.com/en/sql-reference/functions/zeroifnull -def _zeroifnull_to_if(args): +def _zeroifnull_to_if(args: t.Sequence) -> exp.Expression: cond = exp.Is(this=seq_get(args, 0), expression=exp.Null()) return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0)) # https://docs.snowflake.com/en/sql-reference/functions/zeroifnull -def _nullifzero_to_if(args): +def _nullifzero_to_if(args: t.Sequence) -> exp.Expression: cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0)) return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0)) -def _datatype_sql(self, expression): +def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: if expression.this == exp.DataType.Type.ARRAY: return "ARRAY" elif expression.this == exp.DataType.Type.MAP: @@ -155,9 +159,8 @@ class Snowflake(Dialect): "MM": "%m", "mm": "%m", "DD": "%d", - "dd": "%d", - "d": "%-d", - "DY": "%w", + "dd": "%-d", + "DY": "%a", "dy": "%w", "HH24": "%H", "hh24": "%H", @@ -174,6 +177,8 @@ class Snowflake(Dialect): } class Parser(parser.Parser): + QUOTED_PIVOT_COLUMNS = True + FUNCTIONS = { **parser.Parser.FUNCTIONS, "ARRAYAGG": exp.ArrayAgg.from_arg_list, @@ -269,9 +274,14 @@ class Snowflake(Dialect): "$": TokenType.PARAMETER, } + VAR_SINGLE_TOKENS = {"$"} + class Generator(generator.Generator): PARAMETER_TOKEN = "$" MATCHED_BY_SOURCE = False + SINGLE_STRING_INTERVAL = True + JOIN_HINTS = False + TABLE_HINTS = False TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore @@ -287,26 +297,30 @@ class Snowflake(Dialect): ), exp.DateStrToDate: datestrtodate_sql, exp.DataType: _datatype_sql, + exp.DayOfWeek: rename_func("DAYOFWEEK"), exp.If: rename_func("IFF"), - exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), - exp.LogicalOr: rename_func("BOOLOR_AGG"), exp.LogicalAnd: rename_func("BOOLAND_AGG"), - exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), + exp.LogicalOr: rename_func("BOOLOR_AGG"), + exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), + exp.Max: max_or_greatest, + exp.Min: min_or_least, exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", + exp.StarMap: rename_func("OBJECT_CONSTRUCT"), exp.StrPosition: lambda self, e: self.func( "POSITION", e.args.get("substr"), e.this, e.args.get("position") ), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", - exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToTime: timestrtotime_sql, exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", - exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression), + exp.TimeToStr: lambda self, e: self.func( + "TO_CHAR", exp.cast(e.this, "timestamp"), self.format_time(e) + ), + exp.TimestampTrunc: timestamptrunc_sql, exp.ToChar: lambda self, e: self.function_fallback_sql(e), + exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression), exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"), exp.UnixToTime: _unix_to_time_sql, - exp.DayOfWeek: rename_func("DAYOFWEEK"), - exp.Max: max_or_greatest, - exp.Min: min_or_least, + exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), } TYPE_MAPPING = { @@ -322,14 +336,15 @@ class Snowflake(Dialect): PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, # type: ignore exp.SetProperty: exp.Properties.Location.UNSUPPORTED, + exp.VolatileProperty: 
exp.Properties.Location.UNSUPPORTED, } - def except_op(self, expression): + def except_op(self, expression: exp.Except) -> str: if not expression.args.get("distinct", False): self.unsupported("EXCEPT with All is not supported in Snowflake") return super().except_op(expression) - def intersect_op(self, expression): + def intersect_op(self, expression: exp.Intersect) -> str: if not expression.args.get("distinct", False): self.unsupported("INTERSECT with All is not supported in Snowflake") return super().intersect_op(expression) diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index c271f6f..a3e4cce 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -1,13 +1,15 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, parser from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql from sqlglot.dialects.hive import Hive from sqlglot.helper import seq_get -def _create_sql(self, e): - kind = e.args.get("kind") +def _create_sql(self: Hive.Generator, e: exp.Create) -> str: + kind = e.args["kind"] properties = e.args.get("properties") if kind.upper() == "TABLE" and any( @@ -18,13 +20,13 @@ def _create_sql(self, e): return create_with_partitions_sql(self, e) -def _map_sql(self, expression): +def _map_sql(self: Hive.Generator, expression: exp.Map) -> str: keys = self.sql(expression.args["keys"]) values = self.sql(expression.args["values"]) return f"MAP_FROM_ARRAYS({keys}, {values})" -def _str_to_date(self, expression): +def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format == Hive.date_format: @@ -32,7 +34,7 @@ def _str_to_date(self, expression): return f"TO_DATE({this}, {time_format})" -def _unix_to_time(self, expression): +def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str: scale = expression.args.get("scale") timestamp = self.sql(expression, "this") if scale is None: @@ -75,7 +77,11 @@ class Spark(Hive): length=seq_get(args, 1), ), "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, + "BOOLEAN": lambda args: exp.Cast( + this=seq_get(args, 0), to=exp.DataType.build("boolean") + ), "IIF": exp.If.from_arg_list, + "INT": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("int")), "AGGREGATE": exp.Reduce.from_arg_list, "DAYOFWEEK": lambda args: exp.DayOfWeek( this=exp.TsOrDsToDate(this=seq_get(args, 0)), @@ -89,11 +95,16 @@ class Spark(Hive): "WEEKOFYEAR": lambda args: exp.WeekOfYear( this=exp.TsOrDsToDate(this=seq_get(args, 0)), ), + "DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")), "DATE_TRUNC": lambda args: exp.TimestampTrunc( this=seq_get(args, 1), unit=exp.var(seq_get(args, 0)), ), + "STRING": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("string")), "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)), + "TIMESTAMP": lambda args: exp.Cast( + this=seq_get(args, 0), to=exp.DataType.build("timestamp") + ), } FUNCTION_PARSERS = { @@ -108,16 +119,43 @@ class Spark(Hive): "SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"), } - def _parse_add_column(self): + def _parse_add_column(self) -> t.Optional[exp.Expression]: return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema() - def _parse_drop_column(self): + def _parse_drop_column(self) -> t.Optional[exp.Expression]: return self._match_text_seq("DROP", "COLUMNS") and 
self.expression( exp.Drop, this=self._parse_schema(), kind="COLUMNS", ) + def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]: + # Spark doesn't add a suffix to the pivot columns when there's a single aggregation + if len(pivot_columns) == 1: + return [""] + + names = [] + for agg in pivot_columns: + if isinstance(agg, exp.Alias): + names.append(agg.alias) + else: + """ + This case corresponds to aggregations without aliases being used as suffixes + (e.g. col_avg(foo)). We need to unquote identifiers because they're going to + be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. + Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). + + Moreover, function names are lowercased in order to mimic Spark's naming scheme. + """ + agg_all_unquoted = agg.transform( + lambda node: exp.Identifier(this=node.name, quoted=False) + if isinstance(node, exp.Identifier) + else node + ) + names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower")) + + return names + class Generator(Hive.Generator): TYPE_MAPPING = { **Hive.Generator.TYPE_MAPPING, # type: ignore @@ -145,7 +183,7 @@ class Spark(Hive): exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */", exp.StrToDate: _str_to_date, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", - exp.UnixToTime: _unix_to_time, + exp.UnixToTime: _unix_to_time_sql, exp.Create: _create_sql, exp.Map: _map_sql, exp.Reduce: rename_func("AGGREGATE"), diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 4091dbb..4437f82 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -16,7 +16,7 @@ from sqlglot.tokens import TokenType def _date_add_sql(self, expression): modifier = expression.expression - modifier = expression.name if modifier.is_string else self.sql(modifier) + modifier = modifier.name if modifier.is_string else self.sql(modifier) unit = expression.args.get("unit") modifier = f"'{modifier} {unit.name}'" if unit else f"'{modifier}'" return self.func("DATE", expression.this, modifier) @@ -38,6 +38,9 @@ class SQLite(Dialect): } class Generator(generator.Generator): + JOIN_HINTS = False + TABLE_HINTS = False + TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.BOOLEAN: "INTEGER", @@ -82,6 +85,11 @@ class SQLite(Dialect): exp.TryCast: no_trycast_sql, } + PROPERTIES_LOCATION = { + **generator.Generator.PROPERTIES_LOCATION, # type: ignore + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, + } + LIMIT_FETCH = "LIMIT" def cast_sql(self, expression: exp.Cast) -> str: diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py index 2ba1a92..ff19dab 100644 --- a/sqlglot/dialects/starrocks.py +++ b/sqlglot/dialects/starrocks.py @@ -1,7 +1,11 @@ from __future__ import annotations from sqlglot import exp -from sqlglot.dialects.dialect import arrow_json_extract_sql, rename_func +from sqlglot.dialects.dialect import ( + approx_count_distinct_sql, + arrow_json_extract_sql, + rename_func, +) from sqlglot.dialects.mysql import MySQL from sqlglot.helper import seq_get @@ -10,6 +14,7 @@ class StarRocks(MySQL): class Parser(MySQL.Parser): # type: ignore FUNCTIONS = { **MySQL.Parser.FUNCTIONS, + "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list, "DATE_TRUNC": lambda args: exp.TimestampTrunc( this=seq_get(args, 1), unit=seq_get(args, 0) ), @@ -25,6 +30,7 @@ class StarRocks(MySQL): TRANSFORMS = { **MySQL.Generator.TRANSFORMS, # type: 
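
The function-style casts added to Spark's FUNCTIONS above (BOOLEAN, INT, DATE, STRING, TIMESTAMP) become plain exp.Cast nodes, so they transpile into CAST syntax elsewhere. A sketch with illustrative column names:

    import sqlglot

    # Spark's DATE(x)/STRING(y) now parse as casts rather than anonymous functions.
    print(sqlglot.transpile("SELECT DATE(x), STRING(y)", read="spark", write="hive")[0])
    # expected: SELECT CAST(x AS DATE), CAST(y AS STRING)
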
ignore + exp.ApproxDistinct: approx_count_distinct_sql, exp.JSONExtractScalar: arrow_json_extract_sql, exp.JSONExtract: arrow_json_extract_sql, exp.DateDiff: rename_func("DATEDIFF"), diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py index 31b1c8d..792c2b4 100644 --- a/sqlglot/dialects/tableau.py +++ b/sqlglot/dialects/tableau.py @@ -21,6 +21,9 @@ def _count_sql(self, expression): class Tableau(Dialect): class Generator(generator.Generator): + JOIN_HINTS = False + TABLE_HINTS = False + TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore exp.If: _if_sql, @@ -28,6 +31,11 @@ class Tableau(Dialect): exp.Count: _count_sql, } + PROPERTIES_LOCATION = { + **generator.Generator.PROPERTIES_LOCATION, # type: ignore + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, + } + class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 3d43793..331e105 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -1,7 +1,14 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, generator, parser, tokens -from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least +from sqlglot.dialects.dialect import ( + Dialect, + format_time_lambda, + max_or_greatest, + min_or_least, +) from sqlglot.tokens import TokenType @@ -115,7 +122,18 @@ class Teradata(Dialect): return self.expression(exp.RangeN, this=this, expressions=expressions, each=each) + def _parse_cast(self, strict: bool) -> exp.Expression: + cast = t.cast(exp.Cast, super()._parse_cast(strict)) + if cast.to.this == exp.DataType.Type.DATE and self._match(TokenType.FORMAT): + return format_time_lambda(exp.TimeToStr, "teradata")( + [cast.this, self._parse_string()] + ) + return cast + class Generator(generator.Generator): + JOIN_HINTS = False + TABLE_HINTS = False + TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.GEOMETRY: "ST_GEOMETRY", @@ -130,6 +148,7 @@ class Teradata(Dialect): **generator.Generator.TRANSFORMS, exp.Max: max_or_greatest, exp.Min: min_or_least, + exp.TimeToStr: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})", exp.ToChar: lambda self, e: self.function_fallback_sql(e), } diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index b8a227b..9cf56e1 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -96,6 +96,23 @@ def _parse_eomonth(args): return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit)) +def _parse_hashbytes(args): + kind, data = args + kind = kind.name.upper() if kind.is_string else "" + + if kind == "MD5": + args.pop(0) + return exp.MD5(this=data) + if kind in ("SHA", "SHA1"): + args.pop(0) + return exp.SHA(this=data) + if kind == "SHA2_256": + return exp.SHA2(this=data, length=exp.Literal.number(256)) + if kind == "SHA2_512": + return exp.SHA2(this=data, length=exp.Literal.number(512)) + return exp.func("HASHBYTES", *args) + + def generate_date_delta_with_unit_sql(self, e): func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF" return self.func(func, e.text("unit"), e.expression, e.this) @@ -266,6 +283,7 @@ class TSQL(Dialect): "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER, "VARCHAR(MAX)": TokenType.TEXT, "XML": TokenType.XML, + "SYSTEM_USER": TokenType.CURRENT_USER, } # TSQL allows @, # to appear as a variable/identifier prefix @@ -287,6 +305,7 @@ class TSQL(Dialect): "EOMONTH": 
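
_parse_hashbytes above maps known algorithms onto dedicated hash nodes (exp.MD5, exp.SHA, exp.SHA2), so other dialects can emit their native hash functions. A sketch, assuming the generic function fallback on the target side:

    import sqlglot

    # The MD5 literal is consumed and the call becomes an exp.MD5 node.
    print(sqlglot.transpile("SELECT HASHBYTES('MD5', x)", read="tsql", write="snowflake")[0])
    # expected: SELECT MD5(x)
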
_parse_eomonth, "FORMAT": _parse_format, "GETDATE": exp.CurrentTimestamp.from_arg_list, + "HASHBYTES": _parse_hashbytes, "IIF": exp.If.from_arg_list, "ISNULL": exp.Coalesce.from_arg_list, "JSON_VALUE": exp.JSONExtractScalar.from_arg_list, @@ -296,6 +315,14 @@ class TSQL(Dialect): "SYSDATETIME": exp.CurrentTimestamp.from_arg_list, "SUSER_NAME": exp.CurrentUser.from_arg_list, "SUSER_SNAME": exp.CurrentUser.from_arg_list, + "SYSTEM_USER": exp.CurrentUser.from_arg_list, + } + + JOIN_HINTS = { + "LOOP", + "HASH", + "MERGE", + "REMOTE", } VAR_LENGTH_DATATYPES = { @@ -441,11 +468,21 @@ class TSQL(Dialect): exp.TimeToStr: _format_sql, exp.GroupConcat: _string_agg_sql, exp.Max: max_or_greatest, + exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this), exp.Min: min_or_least, + exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this), + exp.SHA2: lambda self, e: self.func( + "HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this + ), } TRANSFORMS.pop(exp.ReturnsProperty) + PROPERTIES_LOCATION = { + **generator.Generator.PROPERTIES_LOCATION, # type: ignore + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, + } + LIMIT_FETCH = "FETCH" def offset_sql(self, expression: exp.Offset) -> str: diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 9011dce..49d3ff6 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -701,6 +701,119 @@ class Condition(Expression): """ return not_(self) + def _binop(self, klass: t.Type[E], other: ExpOrStr, reverse=False) -> E: + this = self + other = convert(other) + if not isinstance(this, klass) and not isinstance(other, klass): + this = _wrap(this, Binary) + other = _wrap(other, Binary) + if reverse: + return klass(this=other, expression=this) + return klass(this=this, expression=other) + + def __getitem__(self, other: ExpOrStr | slice | t.Tuple[ExpOrStr]): + if isinstance(other, slice): + return Between( + this=self, + low=convert(other.start), + high=convert(other.stop), + ) + return Bracket(this=self, expressions=[convert(e) for e in ensure_list(other)]) + + def isin(self, *expressions: ExpOrStr, query: t.Optional[ExpOrStr] = None, **opts) -> In: + return In( + this=self, + expressions=[convert(e) for e in expressions], + query=maybe_parse(query, **opts) if query else None, + ) + + def like(self, other: ExpOrStr) -> Like: + return self._binop(Like, other) + + def ilike(self, other: ExpOrStr) -> ILike: + return self._binop(ILike, other) + + def eq(self, other: ExpOrStr) -> EQ: + return self._binop(EQ, other) + + def neq(self, other: ExpOrStr) -> NEQ: + return self._binop(NEQ, other) + + def rlike(self, other: ExpOrStr) -> RegexpLike: + return self._binop(RegexpLike, other) + + def __lt__(self, other: ExpOrStr) -> LT: + return self._binop(LT, other) + + def __le__(self, other: ExpOrStr) -> LTE: + return self._binop(LTE, other) + + def __gt__(self, other: ExpOrStr) -> GT: + return self._binop(GT, other) + + def __ge__(self, other: ExpOrStr) -> GTE: + return self._binop(GTE, other) + + def __add__(self, other: ExpOrStr) -> Add: + return self._binop(Add, other) + + def __radd__(self, other: ExpOrStr) -> Add: + return self._binop(Add, other, reverse=True) + + def __sub__(self, other: ExpOrStr) -> Sub: + return self._binop(Sub, other) + + def __rsub__(self, other: ExpOrStr) -> Sub: + return self._binop(Sub, other, reverse=True) + + def __mul__(self, other: ExpOrStr) -> Mul: + return self._binop(Mul, other) + + def __rmul__(self, other: ExpOrStr) -> Mul: + return 
self._binop(Mul, other, reverse=True) + + def __truediv__(self, other: ExpOrStr) -> Div: + return self._binop(Div, other) + + def __rtruediv__(self, other: ExpOrStr) -> Div: + return self._binop(Div, other, reverse=True) + + def __floordiv__(self, other: ExpOrStr) -> IntDiv: + return self._binop(IntDiv, other) + + def __rfloordiv__(self, other: ExpOrStr) -> IntDiv: + return self._binop(IntDiv, other, reverse=True) + + def __mod__(self, other: ExpOrStr) -> Mod: + return self._binop(Mod, other) + + def __rmod__(self, other: ExpOrStr) -> Mod: + return self._binop(Mod, other, reverse=True) + + def __pow__(self, other: ExpOrStr) -> Pow: + return self._binop(Pow, other) + + def __rpow__(self, other: ExpOrStr) -> Pow: + return self._binop(Pow, other, reverse=True) + + def __and__(self, other: ExpOrStr) -> And: + return self._binop(And, other) + + def __rand__(self, other: ExpOrStr) -> And: + return self._binop(And, other, reverse=True) + + def __or__(self, other: ExpOrStr) -> Or: + return self._binop(Or, other) + + def __ror__(self, other: ExpOrStr) -> Or: + return self._binop(Or, other, reverse=True) + + def __neg__(self) -> Neg: + return Neg(this=_wrap(self, Binary)) + + def __invert__(self) -> Not: + return not_(self) + class Predicate(Condition): """Relationships like x = y, x > 1, x >= y.""" @@ -818,7 +931,6 @@ class Create(Expression): "properties": False, "replace": False, "unique": False, - "volatile": False, "indexes": False, "no_schema_binding": False, "begin": False, @@ -1053,6 +1165,11 @@ class NotNullColumnConstraint(ColumnConstraintKind): arg_types = {"allow_null": False} +# https://dev.mysql.com/doc/refman/5.7/en/timestamp-initialization.html +class OnUpdateColumnConstraint(ColumnConstraintKind): + pass + + class PrimaryKeyColumnConstraint(ColumnConstraintKind): arg_types = {"desc": False} @@ -1197,6 +1314,7 @@ class Drop(Expression): "materialized": False, "cascade": False, "constraints": False, + "purge": False, } @@ -1287,6 +1405,7 @@ class Insert(Expression): "with": False, "this": True, "expression": False, + "conflict": False, "returning": False, "overwrite": False, "exists": False, @@ -1295,6 +1414,16 @@ class Insert(Expression): } +class OnConflict(Expression): + arg_types = { + "duplicate": False, + "expressions": False, + "nothing": False, + "key": False, + "constraint": False, + } + + class Returning(Expression): arg_types = {"expressions": True} @@ -1326,7 +1455,12 @@ class Partition(Expression): class Fetch(Expression): - arg_types = {"direction": False, "count": False} + arg_types = { + "direction": False, + "count": False, + "percent": False, + "with_ties": False, + } class Group(Expression): @@ -1374,6 +1508,7 @@ class Join(Expression): "kind": False, "using": False, "natural": False, + "hint": False, } @property @@ -1384,6 +1519,10 @@ class Join(Expression): def side(self): return self.text("side").upper() + @property + def hint(self): + return self.text("hint").upper() + @property def alias_or_name(self): return self.this.alias_or_name @@ -1475,6 +1614,7 @@ class MatchRecognize(Expression): "after": False, "pattern": False, "define": False, + "alias": False, } @@ -1582,6 +1722,10 @@ class FreespaceProperty(Property): arg_types = {"this": True, "percent": False} +class InputOutputFormat(Expression): + arg_types = {"input_format": False, "output_format": False} + + class IsolatedLoadingProperty(Property): arg_types = { "no": True, @@ -1646,6 +1790,10 @@ class ReturnsProperty(Property): arg_types = {"this": True, "is_table": False, "table": False} +class 
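
The Condition operator overloads above let predicates be composed directly in Python. A minimal sketch with illustrative column names; the exact parenthesization follows _wrap and may differ slightly:

    from sqlglot import exp

    # Arithmetic, comparison and boolean operators now build expression trees.
    where = (exp.column("x") + 1).eq(2) & exp.column("y").isin(1, 2)
    print(where.sql())
    # roughly: ((x + 1) = 2) AND y IN (1, 2)
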
RowFormatProperty(Property): + arg_types = {"this": True} + + class RowFormatDelimitedProperty(Property): # https://cwiki.apache.org/confluence/display/hive/languagemanual+dml arg_types = { @@ -1683,6 +1831,10 @@ class SqlSecurityProperty(Property): arg_types = {"definer": True} +class StabilityProperty(Property): + arg_types = {"this": True} + + class TableFormatProperty(Property): arg_types = {"this": True} @@ -1695,8 +1847,8 @@ class TransientProperty(Property): arg_types = {"this": False} -class VolatilityProperty(Property): - arg_types = {"this": True} +class VolatileProperty(Property): + arg_types = {"this": False} class WithDataProperty(Property): @@ -1726,6 +1878,7 @@ class Properties(Expression): "LOCATION": LocationProperty, "PARTITIONED_BY": PartitionedByProperty, "RETURNS": ReturnsProperty, + "ROW_FORMAT": RowFormatProperty, "SORTKEY": SortKeyProperty, "TABLE_FORMAT": TableFormatProperty, } @@ -2721,6 +2874,7 @@ class Pivot(Expression): "expressions": True, "field": True, "unpivot": True, + "columns": False, } @@ -2731,6 +2885,8 @@ class Window(Expression): "order": False, "spec": False, "alias": False, + "over": False, + "first": False, } @@ -2816,6 +2972,7 @@ class DataType(Expression): FLOAT = auto() DOUBLE = auto() DECIMAL = auto() + BIGDECIMAL = auto() BIT = auto() BOOLEAN = auto() JSON = auto() @@ -2964,7 +3121,7 @@ class DropPartition(Expression): # Binary expressions like (ADD a b) -class Binary(Expression): +class Binary(Condition): arg_types = {"this": True, "expression": True} @property @@ -2980,7 +3137,7 @@ class Add(Binary): pass -class Connector(Binary, Condition): +class Connector(Binary): pass @@ -3142,7 +3299,7 @@ class ArrayOverlaps(Binary): # Unary Expressions # (NOT a) -class Unary(Expression): +class Unary(Condition): pass @@ -3150,11 +3307,11 @@ class BitwiseNot(Unary): pass -class Not(Unary, Condition): +class Not(Unary): pass -class Paren(Unary, Condition): +class Paren(Unary): arg_types = {"this": True, "with": False} @@ -3162,7 +3319,6 @@ class Neg(Unary): pass -# Special Functions class Alias(Expression): arg_types = {"this": True, "alias": False} @@ -3381,6 +3537,16 @@ class AnyValue(AggFunc): class Case(Func): arg_types = {"this": False, "ifs": True, "default": False} + def when(self, condition: ExpOrStr, then: ExpOrStr, copy: bool = True, **opts) -> Case: + this = self.copy() if copy else self + this.append("ifs", If(this=maybe_parse(condition, **opts), true=maybe_parse(then, **opts))) + return this + + def else_(self, condition: ExpOrStr, copy: bool = True, **opts) -> Case: + this = self.copy() if copy else self + this.set("default", maybe_parse(condition, **opts)) + return this + class Cast(Func): arg_types = {"this": True, "to": True} @@ -3719,6 +3885,10 @@ class Map(Func): arg_types = {"keys": False, "values": False} +class StarMap(Func): + pass + + class VarMap(Func): arg_types = {"keys": True, "values": True} is_var_len_args = True @@ -3734,6 +3904,10 @@ class Max(AggFunc): is_var_len_args = True +class MD5(Func): + _sql_names = ["MD5"] + + class Min(AggFunc): arg_types = {"this": True, "expressions": False} is_var_len_args = True @@ -3840,6 +4014,15 @@ class SetAgg(AggFunc): pass +class SHA(Func): + _sql_names = ["SHA", "SHA1"] + + +class SHA2(Func): + _sql_names = ["SHA2"] + arg_types = {"this": True, "length": False} + + class SortArray(Func): arg_types = {"this": True, "asc": False} @@ -4017,6 +4200,12 @@ class When(Func): arg_types = {"matched": True, "source": False, "condition": False, "then": True} +# 
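
The new when/else_ builders on Case (added just above) support fluent CASE construction; a minimal sketch:

    from sqlglot import exp

    # Conditions and results are strings run through maybe_parse.
    case = exp.Case().when("x = 1", "'one'").else_("'other'")
    print(case.sql())
    # CASE WHEN x = 1 THEN 'one' ELSE 'other' END
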
https://docs.oracle.com/javadb/10.8.3.0/ref/rrefsqljnextvaluefor.html +# https://learn.microsoft.com/en-us/sql/t-sql/functions/next-value-for-transact-sql?view=sql-server-ver16 +class NextValueFor(Func): + arg_types = {"this": True, "order": False} + + def _norm_arg(arg): return arg.lower() if type(arg) is str else arg @@ -4025,6 +4214,32 @@ ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func)) # Helpers +@t.overload +def maybe_parse( + sql_or_expression: ExpOrStr, + *, + into: t.Type[E], + dialect: DialectType = None, + prefix: t.Optional[str] = None, + copy: bool = False, + **opts, +) -> E: + ... + + +@t.overload +def maybe_parse( + sql_or_expression: str | E, + *, + into: t.Optional[IntoType] = None, + dialect: DialectType = None, + prefix: t.Optional[str] = None, + copy: bool = False, + **opts, +) -> E: + ... + + def maybe_parse( sql_or_expression: ExpOrStr, *, @@ -4200,15 +4415,15 @@ def _combine(expressions, operator, dialect=None, **opts): expressions = [condition(expression, dialect=dialect, **opts) for expression in expressions] this = expressions[0] if expressions[1:]: - this = _wrap_operator(this) + this = _wrap(this, Connector) for expression in expressions[1:]: - this = operator(this=this, expression=_wrap_operator(expression)) + this = operator(this=this, expression=_wrap(expression, Connector)) return this -def _wrap_operator(expression): - if isinstance(expression, (And, Or, Not)): - expression = Paren(this=expression) +def _wrap(expression: E, kind: t.Type[Expression]) -> E | Paren: + if isinstance(expression, kind): + return Paren(this=expression) return expression @@ -4506,7 +4721,7 @@ def not_(expression, dialect=None, **opts) -> Not: dialect=dialect, **opts, ) - return Not(this=_wrap_operator(this)) + return Not(this=_wrap(this, Connector)) def paren(expression) -> Paren: @@ -4657,6 +4872,8 @@ def alias_( if table: table_alias = TableAlias(this=alias) + + exp = exp.copy() if isinstance(expression, Expression) else exp exp.set("alias", table_alias) if not isinstance(table, bool): @@ -4864,16 +5081,22 @@ def convert(value) -> Expression: """ if isinstance(value, Expression): return value - if value is None: - return NULL - if isinstance(value, bool): - return Boolean(this=value) if isinstance(value, str): return Literal.string(value) - if isinstance(value, float) and math.isnan(value): + if isinstance(value, bool): + return Boolean(this=value) + if value is None or (isinstance(value, float) and math.isnan(value)): return NULL if isinstance(value, numbers.Number): return Literal.number(value) + if isinstance(value, datetime.datetime): + datetime_literal = Literal.string( + (value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat() + ) + return TimeStrToTime(this=datetime_literal) + if isinstance(value, datetime.date): + date_literal = Literal.string(value.strftime("%Y-%m-%d")) + return DateStrToDate(this=date_literal) if isinstance(value, tuple): return Tuple(expressions=[convert(v) for v in value]) if isinstance(value, list): @@ -4883,14 +5106,6 @@ def convert(value) -> Expression: keys=[convert(k) for k in value], values=[convert(v) for v in value.values()], ) - if isinstance(value, datetime.datetime): - datetime_literal = Literal.string( - (value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat() - ) - return TimeStrToTime(this=datetime_literal) - if isinstance(value, datetime.date): - date_literal = Literal.string(value.strftime("%Y-%m-%d")) - return DateStrToDate(this=date_literal) raise 
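
With the reordered convert above, Python date and datetime values map onto DateStrToDate/TimeStrToTime nodes, and naive datetimes are assumed to be UTC. A sketch:

    import datetime
    from sqlglot import exp

    # A date becomes a DATE_STR_TO_DATE over an ISO-formatted string literal.
    print(exp.convert(datetime.date(2023, 5, 3)).sql())
    # DATE_STR_TO_DATE('2023-05-03')
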
ValueError(f"Cannot convert {value}") @@ -5030,7 +5245,9 @@ def replace_placeholders(expression, *args, **kwargs): return expression.transform(_replace_placeholders, iter(args), **kwargs) -def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True) -> Expression: +def expand( + expression: Expression, sources: t.Dict[str, Subqueryable], copy: bool = True +) -> Expression: """Transforms an expression by expanding all referenced sources into subqueries. Examples: @@ -5038,6 +5255,9 @@ def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True >>> expand(parse_one("select * from x AS z"), {"x": parse_one("select * from y")}).sql() 'SELECT * FROM (SELECT * FROM y) AS z /* source: x */' + >>> expand(parse_one("select * from x AS z"), {"x": parse_one("select * from y"), "y": parse_one("select * from z")}).sql() + 'SELECT * FROM (SELECT * FROM (SELECT * FROM z) AS y /* source: y */) AS z /* source: x */' + Args: expression: The expression to expand. sources: A dictionary of name to Subqueryables. @@ -5054,7 +5274,7 @@ def expand(expression: Expression, sources: t.Dict[str, Subqueryable], copy=True if source: subquery = source.subquery(node.alias or name) subquery.comments = [f"source: {name}"] - return subquery + return subquery.transform(_expand, copy=False) return node return expression.transform(_expand, copy=copy) @@ -5089,8 +5309,8 @@ def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func: from sqlglot.dialects.dialect import Dialect - converted = [convert(arg) for arg in args] - kwargs = {key: convert(value) for key, value in kwargs.items()} + converted: t.List[Expression] = [maybe_parse(arg, dialect=dialect) for arg in args] + kwargs = {key: maybe_parse(value, dialect=dialect) for key, value in kwargs.items()} parser = Dialect.get_or_raise(dialect)().parser() from_args_list = parser.FUNCTIONS.get(name.upper()) diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 8a49d55..bd12d54 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -76,11 +76,13 @@ class Generator: exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}", exp.TemporaryProperty: lambda self, e: f"{'GLOBAL ' if e.args.get('global_') else ''}TEMPORARY", exp.TransientProperty: lambda self, e: "TRANSIENT", - exp.VolatilityProperty: lambda self, e: e.name, + exp.StabilityProperty: lambda self, e: e.name, + exp.VolatileProperty: lambda self, e: "VOLATILE", exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}", exp.CaseSpecificColumnConstraint: lambda self, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC", exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}", exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}", + exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}", exp.UppercaseColumnConstraint: lambda self, e: f"UPPERCASE", exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}", exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}", @@ -110,8 +112,19 @@ class Generator: # Whether or not MERGE ... WHEN MATCHED BY SOURCE is allowed MATCHED_BY_SOURCE = True - # Whether or not limit and fetch are supported - # "ALL", "LIMIT", "FETCH" + # Whether or not the INTERVAL expression works only with values like '1 day' + SINGLE_STRING_INTERVAL = False + + # Whether or not the plural form of date parts like day (i.e. 
"days") is supported in INTERVALs + INTERVAL_ALLOWS_PLURAL_FORM = True + + # Whether or not the TABLESAMPLE clause supports a method name, like BERNOULLI + TABLESAMPLE_WITH_METHOD = True + + # Whether or not to treat the number in TABLESAMPLE (50) as a percentage + TABLESAMPLE_SIZE_IS_PERCENT = False + + # Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH") LIMIT_FETCH = "ALL" TYPE_MAPPING = { @@ -129,6 +142,18 @@ class Generator: "replace": "REPLACE", } + TIME_PART_SINGULARS = { + "microseconds": "microsecond", + "seconds": "second", + "minutes": "minute", + "hours": "hour", + "days": "day", + "weeks": "week", + "months": "month", + "quarters": "quarter", + "years": "year", + } + TOKEN_MAPPING: t.Dict[TokenType, str] = {} STRUCT_DELIMITER = ("<", ">") @@ -168,6 +193,7 @@ class Generator: exp.PartitionedByProperty: exp.Properties.Location.POST_WITH, exp.Property: exp.Properties.Location.POST_WITH, exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA, + exp.RowFormatProperty: exp.Properties.Location.POST_SCHEMA, exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA, exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA, exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA, @@ -175,15 +201,22 @@ class Generator: exp.SetProperty: exp.Properties.Location.POST_CREATE, exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA, exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE, + exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA, exp.TableFormatProperty: exp.Properties.Location.POST_WITH, exp.TemporaryProperty: exp.Properties.Location.POST_CREATE, exp.TransientProperty: exp.Properties.Location.POST_CREATE, - exp.VolatilityProperty: exp.Properties.Location.POST_SCHEMA, + exp.VolatileProperty: exp.Properties.Location.POST_CREATE, exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION, exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME, } - WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary) + JOIN_HINTS = True + TABLE_HINTS = True + + RESERVED_KEYWORDS: t.Set[str] = set() + WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.With) + UNWRAPPED_INTERVAL_VALUES = (exp.Literal, exp.Paren, exp.Column) + SENTINEL_LINE_BREAK = "__SQLGLOT__LB__" __slots__ = ( @@ -322,10 +355,15 @@ class Generator: comment = comment + " " if comment[-1].strip() else comment return comment - def maybe_comment(self, sql: str, expression: exp.Expression) -> str: - comments = expression.comments if self._comments else None + def maybe_comment( + self, + sql: str, + expression: t.Optional[exp.Expression] = None, + comments: t.Optional[t.List[str]] = None, + ) -> str: + comments = (comments or (expression and expression.comments)) if self._comments else None # type: ignore - if not comments: + if not comments or isinstance(expression, exp.Binary): return sql sep = "\n" if self.pretty else " " @@ -621,7 +659,6 @@ class Generator: replace = " OR REPLACE" if expression.args.get("replace") else "" unique = " UNIQUE" if expression.args.get("unique") else "" - volatile = " VOLATILE" if expression.args.get("volatile") else "" postcreate_props_sql = "" if properties_locs.get(exp.Properties.Location.POST_CREATE): @@ -632,7 +669,7 @@ class Generator: wrapped=False, ) - modifiers = "".join((replace, unique, volatile, postcreate_props_sql)) + modifiers = "".join((replace, unique, postcreate_props_sql)) postexpression_props_sql = "" if properties_locs.get(exp.Properties.Location.POST_EXPRESSION): @@ 
-684,6 +721,9 @@ class Generator: def hexstring_sql(self, expression: exp.HexString) -> str: return self.sql(expression, "this") + def bytestring_sql(self, expression: exp.ByteString) -> str: + return self.sql(expression, "this") + def datatype_sql(self, expression: exp.DataType) -> str: type_value = expression.this type_sql = self.TYPE_MAPPING.get(type_value, type_value.value) @@ -695,9 +735,7 @@ class Generator: nested = f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}" if expression.args.get("values") is not None: delimiters = ("[", "]") if type_value == exp.DataType.Type.ARRAY else ("(", ")") - values = ( - f"{delimiters[0]}{self.expressions(expression, 'values')}{delimiters[1]}" - ) + values = f"{delimiters[0]}{self.expressions(expression, key='values')}{delimiters[1]}" else: nested = f"({interior})" @@ -713,7 +751,7 @@ class Generator: this = self.sql(expression, "this") this = f" FROM {this}" if this else "" using_sql = ( - f" USING {self.expressions(expression, 'using', sep=', USING ')}" + f" USING {self.expressions(expression, key='using', sep=', USING ')}" if expression.args.get("using") else "" ) @@ -730,7 +768,10 @@ class Generator: materialized = " MATERIALIZED" if expression.args.get("materialized") else "" cascade = " CASCADE" if expression.args.get("cascade") else "" constraints = " CONSTRAINTS" if expression.args.get("constraints") else "" - return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}{constraints}" + purge = " PURGE" if expression.args.get("purge") else "" + return ( + f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}{constraints}{purge}" + ) def except_sql(self, expression: exp.Except) -> str: return self.prepend_ctes( @@ -746,7 +787,10 @@ class Generator: direction = f" {direction.upper()}" if direction else "" count = expression.args.get("count") count = f" {count}" if count else "" - return f"{self.seg('FETCH')}{direction}{count} ROWS ONLY" + if expression.args.get("percent"): + count = f"{count} PERCENT" + with_ties_or_only = "WITH TIES" if expression.args.get("with_ties") else "ONLY" + return f"{self.seg('FETCH')}{direction}{count} ROWS {with_ties_or_only}" def filter_sql(self, expression: exp.Filter) -> str: this = self.sql(expression, "this") @@ -766,12 +810,24 @@ class Generator: def identifier_sql(self, expression: exp.Identifier) -> str: text = expression.name - text = text.lower() if self.normalize and not expression.quoted else text + lower = text.lower() + text = lower if self.normalize and not expression.quoted else text text = text.replace(self.identifier_end, self._escaped_identifier_end) - if expression.quoted or should_identify(text, self.identify): + if ( + expression.quoted + or should_identify(text, self.identify) + or lower in self.RESERVED_KEYWORDS + ): text = f"{self.identifier_start}{text}{self.identifier_end}" return text + def inputoutputformat_sql(self, expression: exp.InputOutputFormat) -> str: + input_format = self.sql(expression, "input_format") + input_format = f"INPUTFORMAT {input_format}" if input_format else "" + output_format = self.sql(expression, "output_format") + output_format = f"OUTPUTFORMAT {output_format}" if output_format else "" + return self.sep().join((input_format, output_format)) + def national_sql(self, expression: exp.National) -> str: return f"N{self.sql(expression, 'this')}" @@ -984,9 +1040,10 @@ class Generator: self.sql(expression, "partition") if expression.args.get("partition") else "" ) expression_sql = self.sql(expression, "expression") + conflict = 
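
The PURGE flag added to drop_sql (with the matching parser change later in this patch) means Oracle-style drops round-trip instead of losing the keyword; a minimal sketch:

    import sqlglot

    # PURGE is preserved on the way through the AST.
    print(sqlglot.transpile("DROP TABLE t PURGE")[0])
    # DROP TABLE t PURGE
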
self.sql(expression, "conflict") returning = self.sql(expression, "returning") sep = self.sep() if partition_sql else "" - sql = f"INSERT{alternative}{this}{exists}{partition_sql}{sep}{expression_sql}{returning}" + sql = f"INSERT{alternative}{this}{exists}{partition_sql}{sep}{expression_sql}{conflict}{returning}" return self.prepend_ctes(expression, sql) def intersect_sql(self, expression: exp.Intersect) -> str: @@ -1004,6 +1061,19 @@ class Generator: def pseudotype_sql(self, expression: exp.PseudoType) -> str: return expression.name.upper() + def onconflict_sql(self, expression: exp.OnConflict) -> str: + conflict = "ON DUPLICATE KEY" if expression.args.get("duplicate") else "ON CONFLICT" + constraint = self.sql(expression, "constraint") + if constraint: + constraint = f"ON CONSTRAINT {constraint}" + key = self.expressions(expression, key="key", flat=True) + do = "" if expression.args.get("duplicate") else " DO " + nothing = "NOTHING" if expression.args.get("nothing") else "" + expressions = self.expressions(expression, flat=True) + if expressions: + expressions = f"UPDATE SET {expressions}" + return f"{self.seg(conflict)} {constraint}{key}{do}{nothing}{expressions}" + def returning_sql(self, expression: exp.Returning) -> str: return f"{self.seg('RETURNING')} {self.expressions(expression, flat=True)}" @@ -1036,7 +1106,7 @@ class Generator: alias = self.sql(expression, "alias") alias = f"{sep}{alias}" if alias else "" hints = self.expressions(expression, key="hints", sep=", ", flat=True) - hints = f" WITH ({hints})" if hints else "" + hints = f" WITH ({hints})" if hints and self.TABLE_HINTS else "" laterals = self.expressions(expression, key="laterals", sep="") joins = self.expressions(expression, key="joins", sep="") pivots = self.expressions(expression, key="pivots", sep="") @@ -1053,7 +1123,7 @@ class Generator: this = self.sql(expression, "this") alias = "" method = self.sql(expression, "method") - method = f"{method.upper()} " if method else "" + method = f"{method.upper()} " if method and self.TABLESAMPLE_WITH_METHOD else "" numerator = self.sql(expression, "bucket_numerator") denominator = self.sql(expression, "bucket_denominator") field = self.sql(expression, "bucket_field") @@ -1064,6 +1134,8 @@ class Generator: rows = self.sql(expression, "rows") rows = f"{rows} ROWS" if rows else "" size = self.sql(expression, "size") + if size and self.TABLESAMPLE_SIZE_IS_PERCENT: + size = f"{size} PERCENT" seed = self.sql(expression, "seed") seed = f" {seed_prefix} ({seed})" if seed else "" kind = expression.args.get("kind", "TABLESAMPLE") @@ -1154,6 +1226,7 @@ class Generator: "NATURAL" if expression.args.get("natural") else None, expression.side, expression.kind, + expression.hint if self.JOIN_HINTS else None, "JOIN", ) if op @@ -1311,16 +1384,20 @@ class Generator: def matchrecognize_sql(self, expression: exp.MatchRecognize) -> str: partition = self.partition_by_sql(expression) order = self.sql(expression, "order") - measures = self.sql(expression, "measures") - measures = self.seg(f"MEASURES {measures}") if measures else "" + measures = self.expressions(expression, key="measures") + measures = self.seg(f"MEASURES{self.seg(measures)}") if measures else "" rows = self.sql(expression, "rows") rows = self.seg(rows) if rows else "" after = self.sql(expression, "after") after = self.seg(after) if after else "" pattern = self.sql(expression, "pattern") pattern = self.seg(f"PATTERN ({pattern})") if pattern else "" - define = self.sql(expression, "define") - define = self.seg(f"DEFINE {define}") if 
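
onconflict_sql above renders the new OnConflict node; a sketch of a round trip, with parsing support coming from _parse_on_conflict later in this patch:

    import sqlglot

    # ON CONFLICT ... DO NOTHING survives parse and regenerate.
    sql = "INSERT INTO t (x) VALUES (1) ON CONFLICT (x) DO NOTHING"
    print(sqlglot.transpile(sql)[0])
    # expected: INSERT INTO t (x) VALUES (1) ON CONFLICT (x) DO NOTHING
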
define else "" + definition_sqls = [ + f"{self.sql(definition, 'alias')} AS {self.sql(definition, 'this')}" + for definition in expression.args.get("define", []) + ] + definitions = self.expressions(sqls=definition_sqls) + define = self.seg(f"DEFINE{self.seg(definitions)}") if definitions else "" body = "".join( ( partition, @@ -1332,7 +1409,9 @@ class Generator: define, ) ) - return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}" + alias = self.sql(expression, "alias") + alias = f" {alias}" if alias else "" + return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}{alias}" def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str: limit = expression.args.get("limit") @@ -1353,7 +1432,7 @@ class Generator: self.sql(expression, "group"), self.sql(expression, "having"), self.sql(expression, "qualify"), - self.seg("WINDOW ") + self.expressions(expression, "windows", flat=True) + self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True) if expression.args.get("windows") else "", self.sql(expression, "distribute"), @@ -1471,15 +1550,21 @@ class Generator: partition_sql = partition + " " if partition and order else partition spec = expression.args.get("spec") - spec_sql = " " + self.window_spec_sql(spec) if spec else "" + spec_sql = " " + self.windowspec_sql(spec) if spec else "" alias = self.sql(expression, "alias") - this = f"{this} {'AS' if expression.arg_key == 'windows' else 'OVER'}" + over = self.sql(expression, "over") or "OVER" + this = f"{this} {'AS' if expression.arg_key == 'windows' else over}" + + first = expression.args.get("first") + if first is not None: + first = " FIRST " if first else " LAST " + first = first or "" if not partition and not order and not spec and alias: return f"{this} {alias}" - window_args = alias + partition_sql + order_sql + spec_sql + window_args = alias + first + partition_sql + order_sql + spec_sql return f"{this} ({window_args.strip()})" @@ -1487,7 +1572,7 @@ class Generator: partition = self.expressions(expression, key="partition_by", flat=True) return f"PARTITION BY {partition}" if partition else "" - def window_spec_sql(self, expression: exp.WindowSpec) -> str: + def windowspec_sql(self, expression: exp.WindowSpec) -> str: kind = self.sql(expression, "kind") start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ") end = ( @@ -1508,7 +1593,7 @@ class Generator: return f"{this} BETWEEN {low} AND {high}" def bracket_sql(self, expression: exp.Bracket) -> str: - expressions = apply_index_offset(expression.expressions, self.index_offset) + expressions = apply_index_offset(expression.this, expression.expressions, self.index_offset) expressions_sql = ", ".join(self.sql(e) for e in expressions) return f"{self.sql(expression, 'this')}[{expressions_sql}]" @@ -1550,6 +1635,11 @@ class Generator: expressions = self.expressions(expression, flat=True) return f"CONSTRAINT {this} {expressions}" + def nextvaluefor_sql(self, expression: exp.NextValueFor) -> str: + order = expression.args.get("order") + order = f" OVER ({self.order_sql(order, flat=True)})" if order else "" + return f"NEXT VALUE FOR {self.sql(expression, 'this')}{order}" + def extract_sql(self, expression: exp.Extract) -> str: this = self.sql(expression, "this") expression_sql = self.sql(expression, "expression") @@ -1586,7 +1676,7 @@ class Generator: def primarykey_sql(self, expression: exp.ForeignKey) -> str: expressions = self.expressions(expression, flat=True) - options = self.expressions(expression, "options", flat=True, sep=" ") + 
options = self.expressions(expression, key="options", flat=True, sep=" ") options = f" {options}" if options else "" return f"PRIMARY KEY ({expressions}){options}" @@ -1644,17 +1734,20 @@ class Generator: return f"(SELECT {self.sql(unnest)})" def interval_sql(self, expression: exp.Interval) -> str: - this = expression.args.get("this") - if this: - this = ( - f" {this}" - if isinstance(this, exp.Literal) or isinstance(this, exp.Paren) - else f" ({this})" - ) - else: - this = "" unit = self.sql(expression, "unit") + if not self.INTERVAL_ALLOWS_PLURAL_FORM: + unit = self.TIME_PART_SINGULARS.get(unit.lower(), unit) unit = f" {unit}" if unit else "" + + if self.SINGLE_STRING_INTERVAL: + this = expression.this.name if expression.this else "" + return f"INTERVAL '{this}{unit}'" + + this = self.sql(expression, "this") + if this: + unwrapped = isinstance(expression.this, self.UNWRAPPED_INTERVAL_VALUES) + this = f" {this}" if unwrapped else f" ({this})" + return f"INTERVAL{this}{unit}" def return_sql(self, expression: exp.Return) -> str: @@ -1664,7 +1757,7 @@ class Generator: this = self.sql(expression, "this") expressions = self.expressions(expression, flat=True) expressions = f"({expressions})" if expressions else "" - options = self.expressions(expression, "options", flat=True, sep=" ") + options = self.expressions(expression, key="options", flat=True, sep=" ") options = f" {options}" if options else "" return f"REFERENCES {this}{expressions}{options}" @@ -1690,9 +1783,9 @@ class Generator: return f"NOT {self.sql(expression, 'this')}" def alias_sql(self, expression: exp.Alias) -> str: - to_sql = self.sql(expression, "alias") - to_sql = f" AS {to_sql}" if to_sql else "" - return f"{self.sql(expression, 'this')}{to_sql}" + alias = self.sql(expression, "alias") + alias = f" AS {alias}" if alias else "" + return f"{self.sql(expression, 'this')}{alias}" def aliases_sql(self, expression: exp.Aliases) -> str: return f"{self.sql(expression, 'this')} AS ({self.expressions(expression, flat=True)})" @@ -1712,7 +1805,11 @@ class Generator: if not self.pretty: return self.binary(expression, op) - sqls = tuple(self.sql(e) for e in expression.flatten(unnest=False)) + sqls = tuple( + self.maybe_comment(self.sql(e), e, e.parent.comments) if i != 1 else self.sql(e) + for i, e in enumerate(expression.flatten(unnest=False)) + ) + sep = "\n" if self.text_width(sqls) > self._max_text_width else " " return f"{sep}{op} ".join(sqls) @@ -1797,13 +1894,13 @@ class Generator: actions = expression.args["actions"] if isinstance(actions[0], exp.ColumnDef): - actions = self.expressions(expression, "actions", prefix="ADD COLUMN ") + actions = self.expressions(expression, key="actions", prefix="ADD COLUMN ") elif isinstance(actions[0], exp.Schema): - actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ") + actions = self.expressions(expression, key="actions", prefix="ADD COLUMNS ") elif isinstance(actions[0], exp.Delete): - actions = self.expressions(expression, "actions", flat=True) + actions = self.expressions(expression, key="actions", flat=True) else: - actions = self.expressions(expression, "actions") + actions = self.expressions(expression, key="actions") exists = " IF EXISTS" if expression.args.get("exists") else "" return f"ALTER TABLE{exists} {self.sql(expression, 'this')} {actions}" @@ -1935,6 +2032,7 @@ class Generator: return f"USE{kind}{this}" def binary(self, expression: exp.Binary, op: str) -> str: + op = self.maybe_comment(op, comments=expression.comments) return f"{self.sql(expression, 'this')} 
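
Named windows exercise the window alias handling above (the AS-versus-OVER branch in window_sql); a minimal sketch where the statement is expected to round-trip unchanged:

    import sqlglot

    # The OVER w reference and the WINDOW clause both go through window_sql.
    sql = "SELECT SUM(x) OVER w FROM t WINDOW w AS (PARTITION BY y)"
    print(sqlglot.transpile(sql)[0])
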
{op} {self.sql(expression, 'expression')}" def function_fallback_sql(self, expression: exp.Func) -> str: @@ -1965,14 +2063,15 @@ class Generator: def expressions( self, - expression: exp.Expression, + expression: t.Optional[exp.Expression] = None, key: t.Optional[str] = None, + sqls: t.Optional[t.List[str]] = None, flat: bool = False, indent: bool = True, sep: str = ", ", prefix: str = "", ) -> str: - expressions = expression.args.get(key or "expressions") + expressions = expression.args.get(key or "expressions") if expression else sqls if not expressions: return "" diff --git a/sqlglot/helper.py b/sqlglot/helper.py index d44d7dd..b2f0520 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -131,11 +131,16 @@ def subclasses( ] -def apply_index_offset(expressions: t.List[t.Optional[E]], offset: int) -> t.List[t.Optional[E]]: +def apply_index_offset( + this: exp.Expression, + expressions: t.List[t.Optional[E]], + offset: int, +) -> t.List[t.Optional[E]]: """ Applies an offset to a given integer literal expression. Args: + this: the target of the index expressions: the expression the offset will be applied to, wrapped in a list. offset: the offset that will be applied. @@ -148,11 +153,28 @@ def apply_index_offset(expressions: t.List[t.Optional[E]], offset: int) -> t.Lis expression = expressions[0] - if expression and expression.is_int: - expression = expression.copy() - logger.warning("Applying array index offset (%s)", offset) - expression.args["this"] = str(int(expression.this) + offset) # type: ignore - return [expression] + from sqlglot import exp + from sqlglot.optimizer.annotate_types import annotate_types + from sqlglot.optimizer.simplify import simplify + + if not this.type: + annotate_types(this) + + if t.cast(exp.DataType, this.type).this not in ( + exp.DataType.Type.UNKNOWN, + exp.DataType.Type.ARRAY, + ): + return expressions + + if expression: + if not expression.type: + annotate_types(expression) + if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES: + logger.warning("Applying array index offset (%s)", offset) + expression = simplify( + exp.Add(this=expression.copy(), expression=exp.Literal.number(offset)) + ) + return [expression] return expressions diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py index 2e563ae..0eac870 100644 --- a/sqlglot/lineage.py +++ b/sqlglot/lineage.py @@ -20,6 +20,7 @@ class Node: expression: exp.Expression source: exp.Expression downstream: t.List[Node] = field(default_factory=list) + alias: str = "" def walk(self) -> t.Iterator[Node]: yield self @@ -69,14 +70,19 @@ def lineage( optimized = optimize(expression, schema=schema, rules=rules) scope = build_scope(optimized) - tables: t.Dict[str, Node] = {} def to_node( column_name: str, scope: Scope, scope_name: t.Optional[str] = None, upstream: t.Optional[Node] = None, + alias: t.Optional[str] = None, ) -> Node: + aliases = { + dt.alias: dt.comments[0].split()[1] + for dt in scope.derived_tables + if dt.comments and dt.comments[0].startswith("source: ") + } if isinstance(scope.expression, exp.Union): for scope in scope.union_scopes: node = to_node( @@ -84,37 +90,58 @@ def lineage( scope=scope, scope_name=scope_name, upstream=upstream, + alias=aliases.get(scope_name), ) return node - select = next(select for select in scope.selects if select.alias_or_name == column_name) - source = optimize(scope.expression.select(select, append=False), schema=schema, rules=rules) - select = source.selects[0] + # Find the specific select clause that is the source of the column we want. 
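
apply_index_offset is now type-aware: the shift is applied only when the bracketed target is ARRAY-typed or of unknown type, and the literal is adjusted via annotate_types plus simplify. A sketch using Presto's 1-based arrays, with an illustrative column name:

    import sqlglot

    # The 0 index is shifted to 1 for Presto; non-array targets are left alone.
    print(sqlglot.transpile("SELECT x[0]", write="presto")[0])
    # SELECT x[1]
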
+ # This can either be a specific, named select or a generic `*` clause. + select = next( + (select for select in scope.selects if select.alias_or_name == column_name), + exp.Star() if scope.expression.is_star else None, + ) + if not select: + raise ValueError(f"Could not find {column_name} in {scope.expression}") + + if isinstance(scope.expression, exp.Select): + # For better ergonomics in our node labels, replace the full select with + # a version that has only the column we care about. + # "x", SELECT x, y FROM foo + # => "x", SELECT x FROM foo + source = optimize( + scope.expression.select(select, append=False), schema=schema, rules=rules + ) + select = source.selects[0] + else: + source = scope.expression + + # Create the node for this step in the lineage chain, and attach it to the previous one. node = Node( name=f"{scope_name}.{column_name}" if scope_name else column_name, source=source, expression=select, + alias=alias or "", ) - if upstream: upstream.downstream.append(node) + # Find all columns that went into creating this one to list their lineage nodes. for c in set(select.find_all(exp.Column)): table = c.table - source = scope.sources[table] + source = scope.sources.get(table) if isinstance(source, Scope): + # The table itself came from a more specific scope. Recurse into that one using the unaliased column name. to_node( - c.name, - scope=source, - scope_name=table, - upstream=node, + c.name, scope=source, scope_name=table, upstream=node, alias=aliases.get(table) ) else: - if table not in tables: - tables[table] = Node(name=c.sql(), source=source, expression=source) - node.downstream.append(tables[table]) + # The source is not a scope - we've reached the end of the line. At this point, if a source is not found + # it means this column's lineage is unknown. This can happen if the definition of a source used in a query + # is not passed into the `sources` map. 
+ source = source or exp.Placeholder() + node.downstream.append(Node(name=c.sql(), source=source, expression=source)) return node diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 99888c6..6238759 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -116,6 +116,9 @@ class TypeAnnotator: exp.ArrayConcat: lambda self, expr: self._annotate_with_type( expr, exp.DataType.Type.VARCHAR ), + exp.ArraySize: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), + exp.Map: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP), + exp.VarMap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP), exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL), exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"), @@ -335,7 +338,7 @@ class TypeAnnotator: left_type = expression.left.type.this right_type = expression.right.type.this - if isinstance(expression, (exp.And, exp.Or)): + if isinstance(expression, exp.Connector): if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL: expression.type = exp.DataType.Type.NULL elif exp.DataType.Type.NULL in (left_type, right_type): @@ -344,7 +347,7 @@ class TypeAnnotator: ) else: expression.type = exp.DataType.Type.BOOLEAN - elif isinstance(expression, (exp.Condition, exp.Predicate)): + elif isinstance(expression, exp.Predicate): expression.type = exp.DataType.Type.BOOLEAN else: expression.type = self._maybe_coerce(left_type, right_type) diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py index f2df230..40668ef 100644 --- a/sqlglot/optimizer/normalize.py +++ b/sqlglot/optimizer/normalize.py @@ -46,7 +46,9 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = root = node is expression original = node.copy() try: - node = while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache)) + node = node.replace( + while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache)) + ) except OptimizeError as e: logger.info(e) node.replace(original) diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 6eae2b5..0a31246 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -93,6 +93,7 @@ def _expand_using(scope, resolver): if column not in columns: columns[column] = k + source_table = ordered[-1] ordered.append(join_table) join_columns = resolver.get_source_columns(join_table) conditions = [] @@ -102,8 +103,10 @@ def _expand_using(scope, resolver): table = columns.get(identifier) if not table or identifier not in join_columns: - raise OptimizeError(f"Cannot automatically join: {identifier}") + if columns and join_columns: + raise OptimizeError(f"Cannot automatically join: {identifier}") + table = table or source_table conditions.append( exp.condition( exp.EQ( diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index 93e1179..a719ebe 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -65,5 +65,8 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): if not table_alias.name: table_alias.set("this", next_name()) + if isinstance(udtf, exp.Values) and not table_alias.columns: + for i, e in 
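
A sketch of the lineage changes: nodes now carry the alias of the source they came from, and expanded sources are followed recursively. The table names are illustrative, and the output shape is an assumption:

    from sqlglot.lineage import lineage

    # "y" is expanded from the sources map, so the walk reaches "x" through it.
    node = lineage("a", "SELECT a FROM y", sources={"y": "SELECT a FROM x"})
    for n in node.walk():
        print(n.name, n.alias)
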
enumerate(udtf.expressions[0].expressions): + table_alias.append("columns", exp.to_identifier(f"_col_{i}")) return expression diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 28ae86d..4e6c910 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -201,23 +201,24 @@ def _simplify_comparison(expression, left, right, or_=False): return left if (av < bv if or_ else av >= bv) else right # we can't ever shortcut to true because the column could be null - if isinstance(a, exp.LT) and isinstance(b, GT_GTE): - if not or_ and av <= bv: - return exp.false() - elif isinstance(a, exp.GT) and isinstance(b, LT_LTE): - if not or_ and av >= bv: - return exp.false() - elif isinstance(a, exp.EQ): - if isinstance(b, exp.LT): - return exp.false() if av >= bv else a - if isinstance(b, exp.LTE): - return exp.false() if av > bv else a - if isinstance(b, exp.GT): - return exp.false() if av <= bv else a - if isinstance(b, exp.GTE): - return exp.false() if av < bv else a - if isinstance(b, exp.NEQ): - return exp.false() if av == bv else a + if not or_: + if isinstance(a, exp.LT) and isinstance(b, GT_GTE): + if av <= bv: + return exp.false() + elif isinstance(a, exp.GT) and isinstance(b, LT_LTE): + if av >= bv: + return exp.false() + elif isinstance(a, exp.EQ): + if isinstance(b, exp.LT): + return exp.false() if av >= bv else a + if isinstance(b, exp.LTE): + return exp.false() if av > bv else a + if isinstance(b, exp.GT): + return exp.false() if av <= bv else a + if isinstance(b, exp.GTE): + return exp.false() if av < bv else a + if isinstance(b, exp.NEQ): + return exp.false() if av == bv else a return None diff --git a/sqlglot/parser.py b/sqlglot/parser.py index b3b899c..abb23ad 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -18,8 +18,13 @@ from sqlglot.trie import in_trie, new_trie logger = logging.getLogger("sqlglot") +E = t.TypeVar("E", bound=exp.Expression) + def parse_var_map(args: t.Sequence) -> exp.Expression: + if len(args) == 1 and args[0].is_star: + return exp.StarMap(this=args[0]) + keys = [] values = [] for i in range(0, len(args), 2): @@ -108,6 +113,8 @@ class Parser(metaclass=_Parser): TokenType.CURRENT_USER: exp.CurrentUser, } + JOIN_HINTS: t.Set[str] = set() + NESTED_TYPE_TOKENS = { TokenType.ARRAY, TokenType.MAP, @@ -145,6 +152,7 @@ class Parser(metaclass=_Parser): TokenType.DATETIME, TokenType.DATE, TokenType.DECIMAL, + TokenType.BIGDECIMAL, TokenType.UUID, TokenType.GEOGRAPHY, TokenType.GEOMETRY, @@ -221,8 +229,10 @@ class Parser(metaclass=_Parser): TokenType.FORMAT, TokenType.FULL, TokenType.IF, + TokenType.IS, TokenType.ISNULL, TokenType.INTERVAL, + TokenType.KEEP, TokenType.LAZY, TokenType.LEADING, TokenType.LEFT, @@ -235,6 +245,7 @@ class Parser(metaclass=_Parser): TokenType.ONLY, TokenType.OPTIONS, TokenType.ORDINALITY, + TokenType.OVERWRITE, TokenType.PARTITION, TokenType.PERCENT, TokenType.PIVOT, @@ -266,6 +277,8 @@ class Parser(metaclass=_Parser): *NO_PAREN_FUNCTIONS, } + INTERVAL_VARS = ID_VAR_TOKENS - {TokenType.END} + TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - { TokenType.APPLY, TokenType.FULL, @@ -276,6 +289,8 @@ class Parser(metaclass=_Parser): TokenType.WINDOW, } + COMMENT_TABLE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.IS} + UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET} TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH} @@ -400,7 +415,7 @@ class Parser(metaclass=_Parser): COLUMN_OPERATORS = { TokenType.DOT: None, TokenType.DCOLON: lambda self, this, to: self.expression( - exp.Cast, + 
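
The reorganized guard above still folds contradictory ANDed ranges while leaving OR chains alone, since a possibly NULL column makes OR shortcuts unsafe; a minimal sketch:

    import sqlglot
    from sqlglot.optimizer.simplify import simplify

    # x cannot be both below 1 and above 2, so the conjunction collapses.
    print(simplify(sqlglot.parse_one("x < 1 AND x > 2")).sql())
    # FALSE
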
exp.Cast if self.STRICT_CAST else exp.TryCast, this=this, to=to, ), @@ -560,7 +575,7 @@ class Parser(metaclass=_Parser): ), "DEFINER": lambda self: self._parse_definer(), "DETERMINISTIC": lambda self: self.expression( - exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE") + exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE") ), "DISTKEY": lambda self: self._parse_distkey(), "DISTSTYLE": lambda self: self._parse_property_assignment(exp.DistStyleProperty), @@ -571,7 +586,7 @@ class Parser(metaclass=_Parser): "FREESPACE": lambda self: self._parse_freespace(), "GLOBAL": lambda self: self._parse_temporary(global_=True), "IMMUTABLE": lambda self: self.expression( - exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE") + exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE") ), "JOURNAL": lambda self: self._parse_journal( no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL" @@ -600,20 +615,20 @@ class Parser(metaclass=_Parser): "PARTITIONED_BY": lambda self: self._parse_partitioned_by(), "RETURNS": lambda self: self._parse_returns(), "ROW": lambda self: self._parse_row(), + "ROW_FORMAT": lambda self: self._parse_property_assignment(exp.RowFormatProperty), "SET": lambda self: self.expression(exp.SetProperty, multi=False), "SORTKEY": lambda self: self._parse_sortkey(), "STABLE": lambda self: self.expression( - exp.VolatilityProperty, this=exp.Literal.string("STABLE") + exp.StabilityProperty, this=exp.Literal.string("STABLE") ), - "STORED": lambda self: self._parse_property_assignment(exp.FileFormatProperty), + "STORED": lambda self: self._parse_stored(), "TABLE_FORMAT": lambda self: self._parse_property_assignment(exp.TableFormatProperty), "TBLPROPERTIES": lambda self: self._parse_wrapped_csv(self._parse_property), + "TEMP": lambda self: self._parse_temporary(global_=False), "TEMPORARY": lambda self: self._parse_temporary(global_=False), "TRANSIENT": lambda self: self.expression(exp.TransientProperty), "USING": lambda self: self._parse_property_assignment(exp.TableFormatProperty), - "VOLATILE": lambda self: self.expression( - exp.VolatilityProperty, this=exp.Literal.string("VOLATILE") - ), + "VOLATILE": lambda self: self._parse_volatile_property(), "WITH": lambda self: self._parse_with_property(), } @@ -648,8 +663,11 @@ class Parser(metaclass=_Parser): "LIKE": lambda self: self._parse_create_like(), "NOT": lambda self: self._parse_not_constraint(), "NULL": lambda self: self.expression(exp.NotNullColumnConstraint, allow_null=True), + "ON": lambda self: self._match(TokenType.UPDATE) + and self.expression(exp.OnUpdateColumnConstraint, this=self._parse_function()), "PATH": lambda self: self.expression(exp.PathColumnConstraint, this=self._parse_string()), "PRIMARY KEY": lambda self: self._parse_primary_key(), + "REFERENCES": lambda self: self._parse_references(match=False), "TITLE": lambda self: self.expression( exp.TitleColumnConstraint, this=self._parse_var_or_string() ), @@ -668,9 +686,14 @@ class Parser(metaclass=_Parser): SCHEMA_UNNAMED_CONSTRAINTS = {"CHECK", "FOREIGN KEY", "LIKE", "PRIMARY KEY", "UNIQUE"} NO_PAREN_FUNCTION_PARSERS = { + TokenType.ANY: lambda self: self.expression(exp.Any, this=self._parse_bitwise()), TokenType.CASE: lambda self: self._parse_case(), TokenType.IF: lambda self: self._parse_if(), - TokenType.ANY: lambda self: self.expression(exp.Any, this=self._parse_bitwise()), + TokenType.NEXT_VALUE_FOR: lambda self: self.expression( + exp.NextValueFor, + this=self._parse_column(), + order=self._match(TokenType.OVER) and 
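
VOLATILE is now disambiguated by position (see _parse_volatile_property later in this patch): immediately after CREATE, REPLACE or UNIQUE it is a table property, elsewhere a function stability marker. A sketch, assuming Teradata-style volatile tables:

    import sqlglot

    # VOLATILE here is a table property, not a StabilityProperty.
    print(sqlglot.transpile("CREATE VOLATILE TABLE t (c INT)", read="teradata", write="teradata")[0])
    # expected: CREATE VOLATILE TABLE t (c INT)
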
self._parse_wrapped(self._parse_order), + ), } FUNCTION_PARSERS: t.Dict[str, t.Callable] = { @@ -715,6 +738,8 @@ class Parser(metaclass=_Parser): SHOW_PARSERS: t.Dict[str, t.Callable] = {} + TYPE_LITERAL_PARSERS: t.Dict[exp.DataType.Type, t.Callable] = {} + MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table) TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"} @@ -731,6 +756,7 @@ class Parser(metaclass=_Parser): INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"} WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS} + WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER} ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY} @@ -738,6 +764,9 @@ class Parser(metaclass=_Parser): CONVERT_TYPE_FIRST = False + QUOTED_PIVOT_COLUMNS: t.Optional[bool] = None + PREFIXED_PIVOT_COLUMNS = False + LOG_BASE_FIRST = True LOG_DEFAULTS_TO_LN = False @@ -895,8 +924,8 @@ class Parser(metaclass=_Parser): error level setting. """ token = token or self._curr or self._prev or Token.string("") - start = self._find_token(token) - end = start + len(token.text) + start = token.start + end = token.end start_context = self.sql[max(start - self.error_message_context, 0) : start] highlight = self.sql[start:end] end_context = self.sql[end : end + self.error_message_context] @@ -918,8 +947,8 @@ class Parser(metaclass=_Parser): self.errors.append(error) def expression( - self, exp_class: t.Type[exp.Expression], comments: t.Optional[t.List[str]] = None, **kwargs - ) -> exp.Expression: + self, exp_class: t.Type[E], comments: t.Optional[t.List[str]] = None, **kwargs + ) -> E: """ Creates a new, validated Expression. @@ -958,22 +987,7 @@ class Parser(metaclass=_Parser): self.raise_error(error_message) def _find_sql(self, start: Token, end: Token) -> str: - return self.sql[self._find_token(start) : self._find_token(end) + len(end.text)] - - def _find_token(self, token: Token) -> int: - line = 1 - col = 1 - index = 0 - - while line < token.line or col < token.col: - if Tokenizer.WHITE_SPACE.get(self.sql[index]) == TokenType.BREAK: - line += 1 - col = 1 - else: - col += 1 - index += 1 - - return index + return self.sql[start.start : end.end] def _advance(self, times: int = 1) -> None: self._index += times @@ -990,7 +1004,7 @@ class Parser(metaclass=_Parser): if index != self._index: self._advance(index - self._index) - def _parse_command(self) -> exp.Expression: + def _parse_command(self) -> exp.Command: return self.expression(exp.Command, this=self._prev.text, expression=self._parse_string()) def _parse_comment(self, allow_exists: bool = True) -> exp.Expression: @@ -1007,7 +1021,7 @@ class Parser(metaclass=_Parser): if kind.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE): this = self._parse_user_defined_function(kind=kind.token_type) elif kind.token_type == TokenType.TABLE: - this = self._parse_table() + this = self._parse_table(alias_tokens=self.COMMENT_TABLE_ALIAS_TOKENS) elif kind.token_type == TokenType.COLUMN: this = self._parse_column() else: @@ -1035,16 +1049,13 @@ class Parser(metaclass=_Parser): self._parse_query_modifiers(expression) return expression - def _parse_drop(self, default_kind: t.Optional[str] = None) -> t.Optional[exp.Expression]: + def _parse_drop(self) -> t.Optional[exp.Drop | exp.Command]: start = self._prev temporary = self._match(TokenType.TEMPORARY) materialized = self._match(TokenType.MATERIALIZED) kind = self._match_set(self.CREATABLES) and self._prev.text if not kind: - if default_kind: - kind = default_kind - else: - return 
self._parse_as_command(start) + return self._parse_as_command(start) return self.expression( exp.Drop, @@ -1055,6 +1066,7 @@ class Parser(metaclass=_Parser): materialized=materialized, cascade=self._match(TokenType.CASCADE), constraints=self._match_text_seq("CONSTRAINTS"), + purge=self._match_text_seq("PURGE"), ) def _parse_exists(self, not_: bool = False) -> t.Optional[bool]: @@ -1070,7 +1082,6 @@ class Parser(metaclass=_Parser): TokenType.OR, TokenType.REPLACE ) unique = self._match(TokenType.UNIQUE) - volatile = self._match(TokenType.VOLATILE) if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False): self._match(TokenType.TABLE) @@ -1179,7 +1190,6 @@ class Parser(metaclass=_Parser): kind=create_token.text, replace=replace, unique=unique, - volatile=volatile, expression=expression, exists=exists, properties=properties, @@ -1225,6 +1235,21 @@ class Parser(metaclass=_Parser): return None + def _parse_stored(self) -> exp.Expression: + self._match(TokenType.ALIAS) + + input_format = self._parse_string() if self._match_text_seq("INPUTFORMAT") else None + output_format = self._parse_string() if self._match_text_seq("OUTPUTFORMAT") else None + + return self.expression( + exp.FileFormatProperty, + this=self.expression( + exp.InputOutputFormat, input_format=input_format, output_format=output_format + ) + if input_format or output_format + else self._parse_var_or_string() or self._parse_number() or self._parse_id_var(), + ) + def _parse_property_assignment(self, exp_class: t.Type[exp.Expression]) -> exp.Expression: self._match(TokenType.EQ) self._match(TokenType.ALIAS) @@ -1258,6 +1283,21 @@ class Parser(metaclass=_Parser): exp.FallbackProperty, no=no, protection=self._match_text_seq("PROTECTION") ) + def _parse_volatile_property(self) -> exp.Expression: + if self._index >= 2: + pre_volatile_token = self._tokens[self._index - 2] + else: + pre_volatile_token = None + + if pre_volatile_token and pre_volatile_token.token_type in ( + TokenType.CREATE, + TokenType.REPLACE, + TokenType.UNIQUE, + ): + return exp.VolatileProperty() + + return self.expression(exp.StabilityProperty, this=exp.Literal.string("VOLATILE")) + def _parse_with_property( self, ) -> t.Union[t.Optional[exp.Expression], t.List[t.Optional[exp.Expression]]]: @@ -1574,11 +1614,46 @@ class Parser(metaclass=_Parser): exists=self._parse_exists(), partition=self._parse_partition(), expression=self._parse_ddl_select(), + conflict=self._parse_on_conflict(), returning=self._parse_returning(), overwrite=overwrite, alternative=alternative, ) + def _parse_on_conflict(self) -> t.Optional[exp.Expression]: + conflict = self._match_text_seq("ON", "CONFLICT") + duplicate = self._match_text_seq("ON", "DUPLICATE", "KEY") + + if not (conflict or duplicate): + return None + + nothing = None + expressions = None + key = None + constraint = None + + if conflict: + if self._match_text_seq("ON", "CONSTRAINT"): + constraint = self._parse_id_var() + else: + key = self._parse_csv(self._parse_value) + + self._match_text_seq("DO") + if self._match_text_seq("NOTHING"): + nothing = True + else: + self._match(TokenType.UPDATE) + expressions = self._match(TokenType.SET) and self._parse_csv(self._parse_equality) + + return self.expression( + exp.OnConflict, + duplicate=duplicate, + expressions=expressions, + nothing=nothing, + key=key, + constraint=constraint, + ) + def _parse_returning(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.RETURNING): return None @@ -1639,7 +1714,7 @@ class Parser(metaclass=_Parser): return self.expression( 
exp.Delete, - this=self._parse_table(schema=True), + this=self._parse_table(), using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table()), where=self._parse_where(), returning=self._parse_returning(), @@ -1792,6 +1867,7 @@ class Parser(metaclass=_Parser): if not skip_with_token and not self._match(TokenType.WITH): return None + comments = self._prev_comments recursive = self._match(TokenType.RECURSIVE) expressions = [] @@ -1803,7 +1879,9 @@ class Parser(metaclass=_Parser): else: self._match(TokenType.WITH) - return self.expression(exp.With, expressions=expressions, recursive=recursive) + return self.expression( + exp.With, comments=comments, expressions=expressions, recursive=recursive + ) def _parse_cte(self) -> exp.Expression: alias = self._parse_table_alias() @@ -1856,15 +1934,20 @@ class Parser(metaclass=_Parser): table = isinstance(this, exp.Table) while True: - lateral = self._parse_lateral() join = self._parse_join() - comma = None if table else self._match(TokenType.COMMA) - if lateral: - this.append("laterals", lateral) if join: this.append("joins", join) + + lateral = None + if not join: + lateral = self._parse_lateral() + if lateral: + this.append("laterals", lateral) + + comma = None if table else self._match(TokenType.COMMA) if comma: this.args["from"].append("expressions", self._parse_table()) + if not (lateral or join or comma): break @@ -1906,14 +1989,13 @@ class Parser(metaclass=_Parser): def _parse_match_recognize(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.MATCH_RECOGNIZE): return None + self._match_l_paren() partition = self._parse_partition_by() order = self._parse_order() measures = ( - self._parse_alias(self._parse_conjunction()) - if self._match_text_seq("MEASURES") - else None + self._parse_csv(self._parse_expression) if self._match_text_seq("MEASURES") else None ) if self._match_text_seq("ONE", "ROW", "PER", "MATCH"): @@ -1967,8 +2049,17 @@ class Parser(metaclass=_Parser): pattern = None define = ( - self._parse_alias(self._parse_conjunction()) if self._match_text_seq("DEFINE") else None + self._parse_csv( + lambda: self.expression( + exp.Alias, + alias=self._parse_id_var(any_token=True), + this=self._match(TokenType.ALIAS) and self._parse_conjunction(), + ) + ) + if self._match_text_seq("DEFINE") + else None ) + self._match_r_paren() return self.expression( @@ -1980,6 +2071,7 @@ class Parser(metaclass=_Parser): after=after, pattern=pattern, define=define, + alias=self._parse_table_alias(), ) def _parse_lateral(self) -> t.Optional[exp.Expression]: @@ -2022,9 +2114,6 @@ class Parser(metaclass=_Parser): alias=table_alias, ) - if outer_apply or cross_apply: - return self.expression(exp.Join, this=expression, side=None if cross_apply else "LEFT") - return expression def _parse_join_side_and_kind( @@ -2037,11 +2126,26 @@ class Parser(metaclass=_Parser): ) def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]: + index = self._index natural, side, kind = self._parse_join_side_and_kind() + hint = self._prev.text if self._match_texts(self.JOIN_HINTS) else None + join = self._match(TokenType.JOIN) - if not skip_join_token and not self._match(TokenType.JOIN): + if not skip_join_token and not join: + self._retreat(index) + kind = None + natural = None + side = None + + outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY, False) + cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY, False) + + if not skip_join_token and not join and not outer_apply and not cross_apply: return 
None + if outer_apply: + side = Token(TokenType.LEFT, "LEFT") + kwargs: t.Dict[ str, t.Optional[exp.Expression] | bool | str | t.List[t.Optional[exp.Expression]] ] = {"this": self._parse_table()} @@ -2052,6 +2156,8 @@ class Parser(metaclass=_Parser): kwargs["side"] = side.text if kind: kwargs["kind"] = kind.text + if hint: + kwargs["hint"] = hint if self._match(TokenType.ON): kwargs["on"] = self._parse_conjunction() @@ -2179,7 +2285,7 @@ class Parser(metaclass=_Parser): return None expressions = self._parse_wrapped_csv(self._parse_column) - ordinality = bool(self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY)) + ordinality = self._match_pair(TokenType.WITH, TokenType.ORDINALITY) alias = self._parse_table_alias() if alias and self.unnest_column_only: @@ -2191,7 +2297,7 @@ class Parser(metaclass=_Parser): offset = None if self._match_pair(TokenType.WITH, TokenType.OFFSET): self._match(TokenType.ALIAS) - offset = self._parse_conjunction() + offset = self._parse_id_var() or exp.Identifier(this="offset") return self.expression( exp.Unnest, @@ -2294,6 +2400,9 @@ class Parser(metaclass=_Parser): else: expressions = self._parse_csv(lambda: self._parse_alias(self._parse_function())) + if not expressions: + self.raise_error("Failed to parse PIVOT's aggregation list") + if not self._match(TokenType.FOR): self.raise_error("Expecting FOR") @@ -2311,8 +2420,26 @@ class Parser(metaclass=_Parser): if not self._match_set((TokenType.PIVOT, TokenType.UNPIVOT), advance=False): pivot.set("alias", self._parse_table_alias()) + if not unpivot: + names = self._pivot_column_names(t.cast(t.List[exp.Expression], expressions)) + + columns: t.List[exp.Expression] = [] + for col in pivot.args["field"].expressions: + for name in names: + if self.PREFIXED_PIVOT_COLUMNS: + name = f"{name}_{col.alias_or_name}" if name else col.alias_or_name + else: + name = f"{col.alias_or_name}_{name}" if name else col.alias_or_name + + columns.append(exp.to_identifier(name, quoted=self.QUOTED_PIVOT_COLUMNS)) + + pivot.set("columns", columns) + return pivot + def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]: + return [agg.alias for agg in pivot_columns] + def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Expression]: if not skip_where_token and not self._match(TokenType.WHERE): return None @@ -2433,10 +2560,25 @@ class Parser(metaclass=_Parser): if self._match(TokenType.FETCH): direction = self._match_set((TokenType.FIRST, TokenType.NEXT)) direction = self._prev.text if direction else "FIRST" + count = self._parse_number() + percent = self._match(TokenType.PERCENT) + self._match_set((TokenType.ROW, TokenType.ROWS)) - self._match(TokenType.ONLY) - return self.expression(exp.Fetch, direction=direction, count=count) + + only = self._match(TokenType.ONLY) + with_ties = self._match_text_seq("WITH", "TIES") + + if only and with_ties: + self.raise_error("Cannot specify both ONLY and WITH TIES in FETCH clause") + + return self.expression( + exp.Fetch, + direction=direction, + count=count, + percent=percent, + with_ties=with_ties, + ) return this @@ -2493,7 +2635,11 @@ class Parser(metaclass=_Parser): negate = self._match(TokenType.NOT) if self._match_set(self.RANGE_PARSERS): - this = self.RANGE_PARSERS[self._prev.token_type](self, this) + expression = self.RANGE_PARSERS[self._prev.token_type](self, this) + if not expression: + return this + + this = expression elif self._match(TokenType.ISNULL): this = self.expression(exp.Is, this=this, expression=exp.Null()) @@ -2511,17 
+2657,19 @@ class Parser(metaclass=_Parser): return this - def _parse_is(self, this: t.Optional[exp.Expression]) -> exp.Expression: + def _parse_is(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: + index = self._index - 1 negate = self._match(TokenType.NOT) if self._match(TokenType.DISTINCT_FROM): klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ return self.expression(klass, this=this, expression=self._parse_expression()) - this = self.expression( - exp.Is, - this=this, - expression=self._parse_null() or self._parse_boolean(), - ) + expression = self._parse_null() or self._parse_boolean() + if not expression: + self._retreat(index) + return None + + this = self.expression(exp.Is, this=this, expression=expression) return self.expression(exp.Not, this=this) if negate else this def _parse_in(self, this: t.Optional[exp.Expression]) -> exp.Expression: @@ -2553,6 +2701,27 @@ class Parser(metaclass=_Parser): return this return self.expression(exp.Escape, this=this, expression=self._parse_string()) + def _parse_interval(self) -> t.Optional[exp.Expression]: + if not self._match(TokenType.INTERVAL): + return None + + this = self._parse_primary() or self._parse_term() + unit = self._parse_function() or self._parse_var() + + # Most dialects support, e.g., the form INTERVAL '5' day, thus we try to parse + # each INTERVAL expression into this canonical form so it's easy to transpile + if this and isinstance(this, exp.Literal): + if this.is_number: + this = exp.Literal.string(this.name) + + # Try to not clutter Snowflake's multi-part intervals like INTERVAL '1 day, 1 year' + parts = this.name.split() + if not unit and len(parts) <= 2: + this = exp.Literal.string(seq_get(parts, 0)) + unit = self.expression(exp.Var, this=seq_get(parts, 1)) + + return self.expression(exp.Interval, this=this, unit=unit) + def _parse_bitwise(self) -> t.Optional[exp.Expression]: this = self._parse_term() @@ -2588,20 +2757,24 @@ class Parser(metaclass=_Parser): return self._parse_at_time_zone(self._parse_type()) def _parse_type(self) -> t.Optional[exp.Expression]: - if self._match(TokenType.INTERVAL): - return self.expression(exp.Interval, this=self._parse_term(), unit=self._parse_field()) + interval = self._parse_interval() + if interval: + return interval index = self._index - type_token = self._parse_types(check_func=True) + data_type = self._parse_types(check_func=True) this = self._parse_column() - if type_token: + if data_type: if isinstance(this, exp.Literal): - return self.expression(exp.Cast, this=this, to=type_token) - if not type_token.args.get("expressions"): + parser = self.TYPE_LITERAL_PARSERS.get(data_type.this) + if parser: + return parser(self, this, data_type) + return self.expression(exp.Cast, this=this, to=data_type) + if not data_type.args.get("expressions"): self._retreat(index) return self._parse_column() - return type_token + return data_type return this @@ -2631,11 +2804,10 @@ class Parser(metaclass=_Parser): else: expressions = self._parse_csv(self._parse_conjunction) - if not expressions: + if not expressions or not self._match(TokenType.R_PAREN): self._retreat(index) return None - self._match_r_paren() maybe_func = True if self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET): @@ -2720,15 +2892,14 @@ class Parser(metaclass=_Parser): ) def _parse_struct_kwargs(self) -> t.Optional[exp.Expression]: - if self._curr and self._curr.token_type in self.TYPE_TOKENS: - return self._parse_types() - + index = self._index this = self._parse_id_var() self._match(TokenType.COLON) 
data_type = self._parse_types() if not data_type: - return None + self._retreat(index) + return self._parse_types() return self.expression(exp.StructKwarg, this=this, expression=data_type) def _parse_at_time_zone(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: @@ -2825,6 +2996,7 @@ class Parser(metaclass=_Parser): this = self.expression(exp.Paren, this=self._parse_set_operations(this)) self._match_r_paren() + comments.extend(self._prev_comments) if this and comments: this.comments = comments @@ -2833,8 +3005,16 @@ class Parser(metaclass=_Parser): return None - def _parse_field(self, any_token: bool = False) -> t.Optional[exp.Expression]: - return self._parse_primary() or self._parse_function() or self._parse_id_var(any_token) + def _parse_field( + self, + any_token: bool = False, + tokens: t.Optional[t.Collection[TokenType]] = None, + ) -> t.Optional[exp.Expression]: + return ( + self._parse_primary() + or self._parse_function() + or self._parse_id_var(any_token=any_token, tokens=tokens) + ) def _parse_function( self, functions: t.Optional[t.Dict[str, t.Callable]] = None @@ -3079,12 +3259,10 @@ class Parser(metaclass=_Parser): return None def _parse_column_constraint(self) -> t.Optional[exp.Expression]: - this = self._parse_references() - if this: - return this - if self._match(TokenType.CONSTRAINT): this = self._parse_id_var() + else: + this = None if self._match_texts(self.CONSTRAINT_PARSERS): return self.expression( @@ -3164,8 +3342,8 @@ class Parser(metaclass=_Parser): return options - def _parse_references(self) -> t.Optional[exp.Expression]: - if not self._match(TokenType.REFERENCES): + def _parse_references(self, match=True) -> t.Optional[exp.Expression]: + if match and not self._match(TokenType.REFERENCES): return None expressions = None @@ -3234,7 +3412,7 @@ class Parser(metaclass=_Parser): elif not this or this.name.upper() == "ARRAY": this = self.expression(exp.Array, expressions=expressions) else: - expressions = apply_index_offset(expressions, -self.index_offset) + expressions = apply_index_offset(this, expressions, -self.index_offset) this = self.expression(exp.Bracket, this=this, expressions=expressions) if not self._match(TokenType.R_BRACKET) and bracket_kind == TokenType.L_BRACKET: @@ -3279,7 +3457,13 @@ class Parser(metaclass=_Parser): self.validate_expression(this, args) self._match_r_paren() else: + index = self._index - 1 condition = self._parse_conjunction() + + if not condition: + self._retreat(index) + return None + self._match(TokenType.THEN) true = self._parse_conjunction() false = self._parse_conjunction() if self._match(TokenType.ELSE) else None @@ -3591,14 +3775,24 @@ class Parser(metaclass=_Parser): # bigquery select from window x AS (partition by ...) 
if alias: + over = None self._match(TokenType.ALIAS) - elif not self._match(TokenType.OVER): + elif not self._match_set(self.WINDOW_BEFORE_PAREN_TOKENS): return this + else: + over = self._prev.text.upper() if not self._match(TokenType.L_PAREN): - return self.expression(exp.Window, this=this, alias=self._parse_id_var(False)) + return self.expression( + exp.Window, this=this, alias=self._parse_id_var(False), over=over + ) window_alias = self._parse_id_var(any_token=False, tokens=self.WINDOW_ALIAS_TOKENS) + + first = self._match(TokenType.FIRST) + if self._match_text_seq("LAST"): + first = False + partition = self._parse_partition_by() order = self._parse_order() kind = self._match_set((TokenType.ROWS, TokenType.RANGE)) and self._prev.text @@ -3629,6 +3823,8 @@ class Parser(metaclass=_Parser): order=order, spec=spec, alias=window_alias, + over=over, + first=first, ) def _parse_window_spec(self) -> t.Dict[str, t.Optional[str | exp.Expression]]: @@ -3886,7 +4082,10 @@ class Parser(metaclass=_Parser): return expression def _parse_drop_column(self) -> t.Optional[exp.Expression]: - return self._match(TokenType.DROP) and self._parse_drop(default_kind="COLUMN") + drop = self._match(TokenType.DROP) and self._parse_drop() + if drop and not isinstance(drop, exp.Command): + drop.set("kind", drop.args.get("kind", "COLUMN")) + return drop # https://docs.aws.amazon.com/athena/latest/ug/alter-table-drop-partition.html def _parse_drop_partition(self, exists: t.Optional[bool] = None) -> exp.Expression: @@ -4010,7 +4209,7 @@ class Parser(metaclass=_Parser): if self._match(TokenType.INSERT): _this = self._parse_star() if _this: - then = self.expression(exp.Insert, this=_this) + then: t.Optional[exp.Expression] = self.expression(exp.Insert, this=_this) else: then = self.expression( exp.Insert, @@ -4239,5 +4438,8 @@ class Parser(metaclass=_Parser): break parent = parent.parent else: - column.replace(dot_or_id) + if column is node: + node = dot_or_id + else: + column.replace(dot_or_id) return node diff --git a/sqlglot/schema.py b/sqlglot/schema.py index 8e39c7f..5d60eb9 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -5,7 +5,7 @@ import typing as t import sqlglot from sqlglot import expressions as exp -from sqlglot.errors import SchemaError +from sqlglot.errors import ParseError, SchemaError from sqlglot.helper import dict_depth from sqlglot.trie import in_trie, new_trie @@ -75,12 +75,11 @@ class AbstractMappingSchema(t.Generic[T]): mapping: dict | None = None, ) -> None: self.mapping = mapping or {} - self.mapping_trie = self._build_trie(self.mapping) + self.mapping_trie = new_trie( + tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self._depth()) + ) self._supported_table_args: t.Tuple[str, ...] = tuple() - def _build_trie(self, schema: t.Dict) -> t.Dict: - return new_trie(tuple(reversed(t)) for t in flatten_schema(schema, depth=self._depth())) - def _depth(self) -> int: return dict_depth(self.mapping) @@ -179,6 +178,64 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): } ) + def add_table( + self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None + ) -> None: + """ + Register or update a table. Updates are only performed if a new column mapping is provided. + + Args: + table: the `Table` expression instance or string representing the table. + column_mapping: a column mapping that describes the structure of the table. 
+ """ + normalized_table = self._normalize_table(self._ensure_table(table)) + normalized_column_mapping = { + self._normalize_name(key): value + for key, value in ensure_column_mapping(column_mapping).items() + } + + schema = self.find(normalized_table, raise_on_missing=False) + if schema and not normalized_column_mapping: + return + + parts = self.table_parts(normalized_table) + + _nested_set( + self.mapping, + tuple(reversed(parts)), + normalized_column_mapping, + ) + new_trie([parts], self.mapping_trie) + + def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]: + table_ = self._normalize_table(self._ensure_table(table)) + schema = self.find(table_) + + if schema is None: + return [] + + if not only_visible or not self.visible: + return list(schema) + + visible = self._nested_get(self.table_parts(table_), self.visible) + return [col for col in schema if col in visible] # type: ignore + + def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType: + column_name = self._normalize_name(column if isinstance(column, str) else column.this) + table_ = self._normalize_table(self._ensure_table(table)) + + table_schema = self.find(table_, raise_on_missing=False) + if table_schema: + column_type = table_schema.get(column_name) + + if isinstance(column_type, exp.DataType): + return column_type + elif isinstance(column_type, str): + return self._to_data_type(column_type.upper()) + raise SchemaError(f"Unknown column type '{column_type}'") + + return exp.DataType.build("unknown") + def _normalize(self, schema: t.Dict) -> t.Dict: """ Converts all identifiers in the schema into lowercase, unless they're quoted. @@ -206,84 +263,37 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): return normalized_mapping - def add_table( - self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None - ) -> None: - """ - Register or update a table. Updates are only performed if a new column mapping is provided. + def _normalize_table(self, table: exp.Table) -> exp.Table: + normalized_table = table.copy() + for arg in TABLE_ARGS: + value = normalized_table.args.get(arg) + if isinstance(value, (str, exp.Identifier)): + normalized_table.set(arg, self._normalize_name(value)) - Args: - table: the `Table` expression instance or string representing the table. - column_mapping: a column mapping that describes the structure of the table. 
- """ - table_ = self._ensure_table(table) - column_mapping = ensure_column_mapping(column_mapping) - schema = self.find(table_, raise_on_missing=False) - - if schema and not column_mapping: - return - - _nested_set( - self.mapping, - list(reversed(self.table_parts(table_))), - column_mapping, - ) - self.mapping_trie = self._build_trie(self.mapping) + return normalized_table - def _normalize_name(self, name: str) -> str: + def _normalize_name(self, name: str | exp.Identifier) -> str: try: - identifier: t.Optional[exp.Expression] = sqlglot.parse_one( - name, read=self.dialect, into=exp.Identifier - ) - except: - identifier = exp.to_identifier(name) - assert isinstance(identifier, exp.Identifier) + identifier = sqlglot.maybe_parse(name, dialect=self.dialect, into=exp.Identifier) + except ParseError: + return name if isinstance(name, str) else name.name - if identifier.quoted: - return identifier.name - return identifier.name.lower() + return identifier.name if identifier.quoted else identifier.name.lower() def _depth(self) -> int: # The columns themselves are a mapping, but we don't want to include those return super()._depth() - 1 def _ensure_table(self, table: exp.Table | str) -> exp.Table: - table_ = exp.to_table(table) + if isinstance(table, exp.Table): + return table + table_ = sqlglot.parse_one(table, read=self.dialect, into=exp.Table) if not table_: raise SchemaError(f"Not a valid table '{table}'") return table_ - def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]: - table_ = self._ensure_table(table) - schema = self.find(table_) - - if schema is None: - return [] - - if not only_visible or not self.visible: - return list(schema) - - visible = self._nested_get(self.table_parts(table_), self.visible) - return [col for col in schema if col in visible] # type: ignore - - def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType: - column_name = column if isinstance(column, str) else column.name - table_ = exp.to_table(table) - if table_: - table_schema = self.find(table_, raise_on_missing=False) - if table_schema: - column_type = table_schema.get(column_name) - - if isinstance(column_type, exp.DataType): - return column_type - elif isinstance(column_type, str): - return self._to_data_type(column_type.upper()) - raise SchemaError(f"Unknown column type '{column_type}'") - return exp.DataType(this=exp.DataType.Type.UNKNOWN) - raise SchemaError(f"Could not convert table '{table}'") - def _to_data_type(self, schema_type: str) -> exp.DataType: """ Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object. @@ -313,7 +323,7 @@ def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema: return MappingSchema(schema, dialect=dialect) -def ensure_column_mapping(mapping: t.Optional[ColumnMapping]): +def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict: if isinstance(mapping, dict): return mapping elif isinstance(mapping, str): @@ -371,7 +381,7 @@ def _nested_get( return d -def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict: +def _nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict: """ In-place set a value for a nested dictionary @@ -384,11 +394,11 @@ def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict: Args: d: dictionary to update. - keys: the keys that makeup the path to `value`. - value: the value to set in the dictionary for the given key path. + keys: the keys that makeup the path to `value`. 
+ value: the value to set in the dictionary for the given key path. - Returns: - The (possibly) updated dictionary. + Returns: + The (possibly) updated dictionary. """ if not keys: return d diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index cf2e31f..64c1f92 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -87,6 +87,7 @@ class TokenType(AutoName): FLOAT = auto() DOUBLE = auto() DECIMAL = auto() + BIGDECIMAL = auto() CHAR = auto() NCHAR = auto() VARCHAR = auto() @@ -214,6 +215,7 @@ class TokenType(AutoName): ISNULL = auto() JOIN = auto() JOIN_MARKER = auto() + KEEP = auto() LANGUAGE = auto() LATERAL = auto() LAZY = auto() @@ -231,6 +233,7 @@ class TokenType(AutoName): MOD = auto() NATURAL = auto() NEXT = auto() + NEXT_VALUE_FOR = auto() NO_ACTION = auto() NOTNULL = auto() NULL = auto() @@ -315,7 +318,7 @@ class TokenType(AutoName): class Token: - __slots__ = ("token_type", "text", "line", "col", "comments") + __slots__ = ("token_type", "text", "line", "col", "end", "comments") @classmethod def number(cls, number: int) -> Token: @@ -343,22 +346,29 @@ class Token: text: str, line: int = 1, col: int = 1, + end: int = 0, comments: t.List[str] = [], ) -> None: self.token_type = token_type self.text = text self.line = line - self.col = col - len(text) - self.col = self.col if self.col > 1 else 1 + size = len(text) + self.col = col + self.end = end if end else size self.comments = comments + @property + def start(self) -> int: + """Returns the start of the token.""" + return self.end - len(self.text) + def __repr__(self) -> str: attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__) return f"<Token {attributes}>" class _Tokenizer(type): - def __new__(cls, clsname, bases, attrs): # type: ignore + def __new__(cls, clsname, bases, attrs): klass = super().__new__(cls, clsname, bases, attrs) klass._QUOTES = { @@ -433,25 +443,25 @@ class Tokenizer(metaclass=_Tokenizer): "#": TokenType.HASH, } - QUOTES: t.List[t.Tuple[str, str] | str] = ["'"] - BIT_STRINGS: t.List[str | t.Tuple[str, str]] = [] - - HEX_STRINGS: t.List[str | t.Tuple[str, str]] = [] - BYTE_STRINGS: t.List[str | t.Tuple[str, str]] = [] - + HEX_STRINGS: t.List[str | t.Tuple[str, str]] = [] IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"'] - - STRING_ESCAPES = ["'"] - - _STRING_ESCAPES: t.Set[str] = set() - IDENTIFIER_ESCAPES = ['"'] + QUOTES: t.List[t.Tuple[str, str] | str] = ["'"] + STRING_ESCAPES = ["'"] + VAR_SINGLE_TOKENS: t.Set[str] = set() + _COMMENTS: t.Dict[str, str] = {} + _BIT_STRINGS: t.Dict[str, str] = {} + _BYTE_STRINGS: t.Dict[str, str] = {} + _HEX_STRINGS: t.Dict[str, str] = {} + _IDENTIFIERS: t.Dict[str, str] = {} _IDENTIFIER_ESCAPES: t.Set[str] = set() + _QUOTES: t.Dict[str, str] = {} + _STRING_ESCAPES: t.Set[str] = set() - KEYWORDS = { + KEYWORDS: t.Dict[t.Optional[str], TokenType] = { **{f"{{%{postfix}": TokenType.BLOCK_START for postfix in ("", "+", "-")}, **{f"{prefix}%}}": TokenType.BLOCK_END for prefix in ("", "+", "-")}, "{{+": TokenType.BLOCK_START, @@ -553,6 +563,7 @@ class Tokenizer(metaclass=_Tokenizer): "IS": TokenType.IS, "ISNULL": TokenType.ISNULL, "JOIN": TokenType.JOIN, + "KEEP": TokenType.KEEP, "LATERAL": TokenType.LATERAL, "LAZY": TokenType.LAZY, "LEADING": TokenType.LEADING, @@ -565,6 +576,7 @@ class Tokenizer(metaclass=_Tokenizer): "MERGE": TokenType.MERGE, "NATURAL": TokenType.NATURAL, "NEXT": TokenType.NEXT, + "NEXT VALUE FOR": TokenType.NEXT_VALUE_FOR, "NO ACTION": TokenType.NO_ACTION, "NOT": TokenType.NOT, "NOTNULL": TokenType.NOTNULL, @@ -632,6 +644,7 @@ class
Tokenizer(metaclass=_Tokenizer): "UPDATE": TokenType.UPDATE, "USE": TokenType.USE, "USING": TokenType.USING, + "UUID": TokenType.UUID, "VALUES": TokenType.VALUES, "VIEW": TokenType.VIEW, "VOLATILE": TokenType.VOLATILE, @@ -661,6 +674,8 @@ class Tokenizer(metaclass=_Tokenizer): "INT8": TokenType.BIGINT, "DEC": TokenType.DECIMAL, "DECIMAL": TokenType.DECIMAL, + "BIGDECIMAL": TokenType.BIGDECIMAL, + "BIGNUMERIC": TokenType.BIGDECIMAL, "MAP": TokenType.MAP, "NULLABLE": TokenType.NULLABLE, "NUMBER": TokenType.DECIMAL, @@ -742,7 +757,7 @@ class Tokenizer(metaclass=_Tokenizer): ENCODE: t.Optional[str] = None COMMENTS = ["--", ("/*", "*/"), ("{#", "#}")] - KEYWORD_TRIE = None # autofilled + KEYWORD_TRIE: t.Dict = {} # autofilled IDENTIFIER_CAN_START_WITH_DIGIT = False @@ -776,19 +791,28 @@ class Tokenizer(metaclass=_Tokenizer): self._col = 1 self._comments: t.List[str] = [] - self._char = None - self._end = None - self._peek = None + self._char = "" + self._end = False + self._peek = "" self._prev_token_line = -1 self._prev_token_comments: t.List[str] = [] - self._prev_token_type = None + self._prev_token_type: t.Optional[TokenType] = None def tokenize(self, sql: str) -> t.List[Token]: """Returns a list of tokens corresponding to the SQL string `sql`.""" self.reset() self.sql = sql self.size = len(sql) - self._scan() + try: + self._scan() + except Exception as e: + start = self._current - 50 + end = self._current + 50 + start = start if start > 0 else 0 + end = end if end < self.size else self.size - 1 + context = self.sql[start:end] + raise ValueError(f"Error tokenizing '{context}'") from e + return self.tokens def _scan(self, until: t.Optional[t.Callable] = None) -> None: @@ -810,9 +834,12 @@ class Tokenizer(metaclass=_Tokenizer): if until and until(): break + if self.tokens: + self.tokens[-1].comments.extend(self._comments) + def _chars(self, size: int) -> str: if size == 1: - return self._char # type: ignore + return self._char start = self._current - 1 end = start + size if end <= self.size: @@ -821,17 +848,15 @@ class Tokenizer(metaclass=_Tokenizer): def _advance(self, i: int = 1) -> None: if self.WHITE_SPACE.get(self._char) is TokenType.BREAK: - self._set_new_line() + self._col = 1 + self._line += 1 + else: + self._col += i - self._col += i self._current += i - self._end = self._current >= self.size # type: ignore - self._char = self.sql[self._current - 1] # type: ignore - self._peek = self.sql[self._current] if self._current < self.size else "" # type: ignore - - def _set_new_line(self) -> None: - self._col = 1 - self._line += 1 + self._end = self._current >= self.size + self._char = self.sql[self._current - 1] + self._peek = "" if self._end else self.sql[self._current] @property def _text(self) -> str: @@ -840,13 +865,14 @@ class Tokenizer(metaclass=_Tokenizer): def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None: self._prev_token_line = self._line self._prev_token_comments = self._comments - self._prev_token_type = token_type # type: ignore + self._prev_token_type = token_type self.tokens.append( Token( token_type, self._text if text is None else text, self._line, self._col, + self._current, self._comments, ) ) @@ -881,7 +907,7 @@ class Tokenizer(metaclass=_Tokenizer): if skip: result = 1 else: - result, trie = in_trie(trie, char.upper()) # type: ignore + result, trie = in_trie(trie, char.upper()) if result == 0: break @@ -910,7 +936,7 @@ class Tokenizer(metaclass=_Tokenizer): if not word: if self._char in self.SINGLE_TOKENS: - 
self._add(self.SINGLE_TOKENS[self._char], text=self._char) # type: ignore + self._add(self.SINGLE_TOKENS[self._char], text=self._char) return self._scan_var() return @@ -927,29 +953,31 @@ class Tokenizer(metaclass=_Tokenizer): self._add(self.KEYWORDS[word], text=word) def _scan_comment(self, comment_start: str) -> bool: - if comment_start not in self._COMMENTS: # type: ignore + if comment_start not in self._COMMENTS: return False comment_start_line = self._line comment_start_size = len(comment_start) - comment_end = self._COMMENTS[comment_start] # type: ignore + comment_end = self._COMMENTS[comment_start] if comment_end: - comment_end_size = len(comment_end) + # Skip the comment's start delimiter + self._advance(comment_start_size) + comment_end_size = len(comment_end) while not self._end and self._chars(comment_end_size) != comment_end: self._advance() - self._comments.append(self._text[comment_start_size : -comment_end_size + 1]) # type: ignore + self._comments.append(self._text[comment_start_size : -comment_end_size + 1]) self._advance(comment_end_size - 1) else: while not self._end and not self.WHITE_SPACE.get(self._peek) is TokenType.BREAK: self._advance() - self._comments.append(self._text[comment_start_size:]) # type: ignore + self._comments.append(self._text[comment_start_size:]) # Leading comment is attached to the succeeding token, whilst trailing comment to the preceding. # Multiple consecutive comments are preserved by appending them to the current comments list. - if comment_start_line == self._prev_token_line or self._end: + if comment_start_line == self._prev_token_line: self.tokens[-1].comments.extend(self._comments) self._comments = [] self._prev_token_line = self._line @@ -958,7 +986,7 @@ class Tokenizer(metaclass=_Tokenizer): def _scan_number(self) -> None: if self._char == "0": - peek = self._peek.upper() # type: ignore + peek = self._peek.upper() if peek == "B": return self._scan_bits() elif peek == "X": @@ -968,7 +996,7 @@ class Tokenizer(metaclass=_Tokenizer): scientific = 0 while True: - if self._peek.isdigit(): # type: ignore + if self._peek.isdigit(): self._advance() elif self._peek == "." 
and not decimal: decimal = True @@ -976,24 +1004,23 @@ class Tokenizer(metaclass=_Tokenizer): elif self._peek in ("-", "+") and scientific == 1: scientific += 1 self._advance() - elif self._peek.upper() == "E" and not scientific: # type: ignore + elif self._peek.upper() == "E" and not scientific: scientific += 1 self._advance() - elif self._peek.isidentifier(): # type: ignore + elif self._peek.isidentifier(): number_text = self._text - literal = [] + literal = "" - while self._peek.strip() and self._peek not in self.SINGLE_TOKENS: # type: ignore - literal.append(self._peek.upper()) # type: ignore + while self._peek.strip() and self._peek not in self.SINGLE_TOKENS: + literal += self._peek.upper() self._advance() - literal = "".join(literal) # type: ignore - token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal)) # type: ignore + token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal)) if token_type: self._add(TokenType.NUMBER, number_text) self._add(TokenType.DCOLON, "::") - return self._add(token_type, literal) # type: ignore + return self._add(token_type, literal) elif self.IDENTIFIER_CAN_START_WITH_DIGIT: return self._add(TokenType.VAR) @@ -1020,7 +1047,7 @@ class Tokenizer(metaclass=_Tokenizer): def _extract_value(self) -> str: while True: - char = self._peek.strip() # type: ignore + char = self._peek.strip() if char and char not in self.SINGLE_TOKENS: self._advance() else: @@ -1029,35 +1056,35 @@ class Tokenizer(metaclass=_Tokenizer): return self._text def _scan_string(self, quote: str) -> bool: - quote_end = self._QUOTES.get(quote) # type: ignore + quote_end = self._QUOTES.get(quote) if quote_end is None: return False self._advance(len(quote)) text = self._extract_string(quote_end) - text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text # type: ignore + text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text self._add(TokenType.NATIONAL if quote[0].upper() == "N" else TokenType.STRING, text) return True # X'1234, b'0110', E'\\\\\' etc. 
def _scan_formatted_string(self, string_start: str) -> bool: - if string_start in self._HEX_STRINGS: # type: ignore - delimiters = self._HEX_STRINGS # type: ignore + if string_start in self._HEX_STRINGS: + delimiters = self._HEX_STRINGS token_type = TokenType.HEX_STRING base = 16 - elif string_start in self._BIT_STRINGS: # type: ignore - delimiters = self._BIT_STRINGS # type: ignore + elif string_start in self._BIT_STRINGS: + delimiters = self._BIT_STRINGS token_type = TokenType.BIT_STRING base = 2 - elif string_start in self._BYTE_STRINGS: # type: ignore - delimiters = self._BYTE_STRINGS # type: ignore + elif string_start in self._BYTE_STRINGS: + delimiters = self._BYTE_STRINGS token_type = TokenType.BYTE_STRING base = None else: return False self._advance(len(string_start)) - string_end = delimiters.get(string_start) + string_end = delimiters[string_start] text = self._extract_string(string_end) if base is None: @@ -1083,20 +1110,20 @@ class Tokenizer(metaclass=_Tokenizer): self._advance() if self._char == identifier_end: if identifier_end_is_escape and self._peek == identifier_end: - text += identifier_end # type: ignore + text += identifier_end self._advance() continue break - text += self._char # type: ignore + text += self._char self._add(TokenType.IDENTIFIER, text) def _scan_var(self) -> None: while True: - char = self._peek.strip() # type: ignore - if char and char not in self.SINGLE_TOKENS: + char = self._peek.strip() + if char and (char in self.VAR_SINGLE_TOKENS or char not in self.SINGLE_TOKENS): self._advance() else: break @@ -1115,9 +1142,9 @@ class Tokenizer(metaclass=_Tokenizer): self._peek == delimiter or self._peek in self._STRING_ESCAPES ): if self._peek == delimiter: - text += self._peek # type: ignore + text += self._peek else: - text += self._char + self._peek # type: ignore + text += self._char + self._peek if self._current + 1 < self.size: self._advance(2) @@ -1131,7 +1158,7 @@ class Tokenizer(metaclass=_Tokenizer): if self._end: raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._start}") - text += self._char # type: ignore + text += self._char self._advance() return text diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 62728d5..00f278e 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -103,7 +103,11 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression: if isinstance(expr, exp.Window): alias = find_new_name(expression.named_selects, "_w") expression.select(exp.alias_(expr.copy(), alias), copy=False) - expr.replace(exp.column(alias)) + column = exp.column(alias) + if isinstance(expr.parent, exp.Qualify): + qualify_filters = column + else: + expr.replace(column) elif expr.name not in expression.named_selects: expression.select(expr.copy(), copy=False) @@ -133,9 +137,111 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expr ) +def unnest_to_explode(expression: exp.Expression) -> exp.Expression: + """Convert cross join unnest into lateral view explode (used in presto -> hive).""" + if isinstance(expression, exp.Select): + for join in expression.args.get("joins") or []: + unnest = join.this + + if isinstance(unnest, exp.Unnest): + alias = unnest.args.get("alias") + udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode + + expression.args["joins"].remove(join) + + for e, column in zip(unnest.expressions, alias.columns if alias else []): + expression.append( + "laterals", + exp.Lateral( + this=udtf(this=e), + view=True, + alias=exp.TableAlias(this=alias.this, 
columns=[column]), # type: ignore + ), + ) + return expression + + +def explode_to_unnest(expression: exp.Expression) -> exp.Expression: + """Convert explode/posexplode into unnest (used in hive -> presto).""" + if isinstance(expression, exp.Select): + from sqlglot.optimizer.scope import build_scope + + taken_select_names = set(expression.named_selects) + taken_source_names = set(build_scope(expression).selected_sources) + + for select in expression.selects: + to_replace = select + + pos_alias = "" + explode_alias = "" + + if isinstance(select, exp.Alias): + explode_alias = select.alias + select = select.this + elif isinstance(select, exp.Aliases): + pos_alias = select.aliases[0].name + explode_alias = select.aliases[1].name + select = select.this + + if isinstance(select, (exp.Explode, exp.Posexplode)): + is_posexplode = isinstance(select, exp.Posexplode) + + explode_arg = select.this + unnest = exp.Unnest(expressions=[explode_arg.copy()], ordinality=is_posexplode) + + # This ensures that we won't use [POS]EXPLODE's argument as a new selection + if isinstance(explode_arg, exp.Column): + taken_select_names.add(explode_arg.output_name) + + unnest_source_alias = find_new_name(taken_source_names, "_u") + taken_source_names.add(unnest_source_alias) + + if not explode_alias: + explode_alias = find_new_name(taken_select_names, "col") + taken_select_names.add(explode_alias) + + if is_posexplode: + pos_alias = find_new_name(taken_select_names, "pos") + taken_select_names.add(pos_alias) + + if is_posexplode: + column_names = [explode_alias, pos_alias] + to_replace.pop() + expression.select(pos_alias, explode_alias, copy=False) + else: + column_names = [explode_alias] + to_replace.replace(exp.column(explode_alias)) + + unnest = exp.alias_(unnest, unnest_source_alias, table=column_names) + + if not expression.args.get("from"): + expression.from_(unnest, copy=False) + else: + expression.join(unnest, join_type="CROSS", copy=False) + + return expression + + +def remove_target_from_merge(expression: exp.Expression) -> exp.Expression: + """Remove table refs from columns in when statements.""" + if isinstance(expression, exp.Merge): + alias = expression.this.args.get("alias") + targets = {expression.this.this} + if alias: + targets.add(alias.this) + + for when in expression.expressions: + when.transform( + lambda node: exp.column(node.name) + if isinstance(node, exp.Column) and node.args.get("table") in targets + else node, + copy=False, + ) + return expression + + def preprocess( transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], - to_sql: t.Callable[[Generator, exp.Expression], str], ) -> t.Callable[[Generator, exp.Expression], str]: """ Creates a new transform by chaining a sequence of transformations and converts the resulting @@ -143,36 +249,23 @@ def preprocess( Args: transforms: sequence of transform functions. These will be called in order. - to_sql: final transform that converts the resulting expression to a SQL string. Returns: Function that can be used as a generator transform. """ - def _to_sql(self, expression): + def _to_sql(self, expression: exp.Expression) -> str: expression = transforms[0](expression.copy()) for t in transforms[1:]: expression = t(expression) - return to_sql(self, expression) + return getattr(self, expression.key + "_sql")(expression) return _to_sql -def delegate(attr: str) -> t.Callable: - """ - Create a new method that delegates to `attr`. This is useful for creating `Generator.TRANSFORMS` - functions that delegate to existing generator methods. 
- """ - - def _transform(self, *args, **kwargs): - return getattr(self, attr)(*args, **kwargs) - - return _transform - - -UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))} -ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on], delegate("select_sql"))} -ELIMINATE_QUALIFY = {exp.Select: preprocess([eliminate_qualify], delegate("select_sql"))} +UNALIAS_GROUP = {exp.Group: preprocess([unalias_group])} +ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on])} +ELIMINATE_QUALIFY = {exp.Select: preprocess([eliminate_qualify])} REMOVE_PRECISION_PARAMETERIZED_TYPES = { - exp.Cast: preprocess([remove_precision_parameterized_types], delegate("cast_sql")) + exp.Cast: preprocess([remove_precision_parameterized_types]) } diff --git a/sqlglot/trie.py b/sqlglot/trie.py index f3b1c38..eba91b9 100644 --- a/sqlglot/trie.py +++ b/sqlglot/trie.py @@ -3,7 +3,7 @@ import typing as t key = t.Sequence[t.Hashable] -def new_trie(keywords: t.Iterable[key]) -> t.Dict: +def new_trie(keywords: t.Iterable[key], trie: t.Optional[t.Dict] = None) -> t.Dict: """ Creates a new trie out of a collection of keywords. @@ -16,11 +16,12 @@ def new_trie(keywords: t.Iterable[key]) -> t.Dict: Args: keywords: the keywords to create the trie from. + trie: a trie to mutate instead of creating a new one Returns: The trie corresponding to `keywords`. """ - trie: t.Dict = {} + trie = {} if trie is None else trie for key in keywords: current = trie -- cgit v1.2.3