diff options
Diffstat (limited to 'sqlglot')
25 files changed, 540 insertions, 181 deletions
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py index f4cfeba..fcfd71e 100644 --- a/sqlglot/dataframe/sql/column.py +++ b/sqlglot/dataframe/sql/column.py @@ -114,7 +114,7 @@ class Column: return self.inverse_binary_op(exp.Or, other) @classmethod - def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]): + def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]) -> Column: return cls(value) @classmethod @@ -259,7 +259,7 @@ class Column: new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null())) return Column(new_expression) - def cast(self, dataType: t.Union[str, DataType]): + def cast(self, dataType: t.Union[str, DataType]) -> Column: """ Functionality Difference: PySpark cast accepts a datatype instance of the datatype class Sqlglot doesn't currently replicate this class so it only accepts a string diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index bdc1fb4..1549a07 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -600,8 +600,13 @@ def months_between( date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None ) -> Column: if roundOff is None: - return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2) - return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2, roundOff) + return Column.invoke_expression_over_column( + date1, expression.MonthsBetween, expression=date2 + ) + + return Column.invoke_expression_over_column( + date1, expression.MonthsBetween, expression=date2, roundoff=roundOff + ) def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column: @@ -614,8 +619,9 @@ def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column: def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column: if format is not None: - return Column.invoke_anonymous_function(col, "TO_TIMESTAMP", lit(format)) - return Column.invoke_anonymous_function(col, "TO_TIMESTAMP") + return Column.invoke_expression_over_column(col, expression.StrToTime, format=lit(format)) + + return Column.ensure_col(col).cast("timestamp") def trunc(col: ColumnOrName, format: str) -> Column: @@ -875,8 +881,16 @@ def regexp_extract(str: ColumnOrName, pattern: str, idx: t.Optional[int] = None) ) -def regexp_replace(str: ColumnOrName, pattern: str, replacement: str) -> Column: - return Column.invoke_anonymous_function(str, "REGEXP_REPLACE", lit(pattern), lit(replacement)) +def regexp_replace( + str: ColumnOrName, pattern: str, replacement: str, position: t.Optional[int] = None +) -> Column: + return Column.invoke_expression_over_column( + str, + expression.RegexpReplace, + expression=lit(pattern), + replacement=lit(replacement), + position=position, + ) def initcap(col: ColumnOrName) -> Column: @@ -1186,7 +1200,9 @@ def transform( f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]], ) -> Column: f_expression = _get_lambda_from_func(f) - return Column.invoke_anonymous_function(col, "TRANSFORM", Column(f_expression)) + return Column.invoke_expression_over_column( + col, expression.Transform, expression=Column(f_expression) + ) def exists(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column: diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 35892f7..fd9965c 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -8,6 +8,7 @@ from sqlglot import exp, generator, parser, tokens, transforms from sqlglot._typing import E from sqlglot.dialects.dialect import ( Dialect, + binary_from_function, datestrtodate_sql, format_time_lambda, inline_array_sql, @@ -15,6 +16,7 @@ from sqlglot.dialects.dialect import ( min_or_least, no_ilike_sql, parse_date_delta_with_interval, + regexp_replace_sql, rename_func, timestrtotime_sql, ts_or_ds_to_date_sql, @@ -39,7 +41,7 @@ def _date_add_sql( def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.Values) -> str: - if not isinstance(expression.unnest().parent, exp.From): + if not expression.find_ancestor(exp.From, exp.Join): return self.values_sql(expression) alias = expression.args.get("alias") @@ -279,7 +281,7 @@ class BigQuery(Dialect): ), "DATETIME_ADD": parse_date_delta_with_interval(exp.DatetimeAdd), "DATETIME_SUB": parse_date_delta_with_interval(exp.DatetimeSub), - "DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)), + "DIV": binary_from_function(exp.IntDiv), "GENERATE_ARRAY": exp.GenerateSeries.from_arg_list, "MD5": exp.MD5Digest.from_arg_list, "TO_HEX": _parse_to_hex, @@ -415,6 +417,7 @@ class BigQuery(Dialect): e.args.get("position"), e.args.get("occurrence"), ), + exp.RegexpReplace: regexp_replace_sql, exp.RegexpLike: rename_func("REGEXP_CONTAINS"), exp.ReturnsProperty: _returnsproperty_sql, exp.Select: transforms.preprocess( diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 9126c4b..8f60df2 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -64,6 +64,7 @@ class ClickHouse(Dialect): "MAP": parse_var_map, "MATCH": exp.RegexpLike.from_arg_list, "UNIQ": exp.ApproxDistinct.from_arg_list, + "XOR": lambda args: exp.Xor(expressions=args), } FUNCTIONS_WITH_ALIASED_ARGS = {*parser.Parser.FUNCTIONS_WITH_ALIASED_ARGS, "TUPLE"} @@ -95,6 +96,7 @@ class ClickHouse(Dialect): TokenType.ASOF, TokenType.ANTI, TokenType.SEMI, + TokenType.ARRAY, } TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - { @@ -103,6 +105,7 @@ class ClickHouse(Dialect): TokenType.ANTI, TokenType.SETTINGS, TokenType.FORMAT, + TokenType.ARRAY, } LOG_DEFAULTS_TO_LN = True @@ -160,8 +163,11 @@ class ClickHouse(Dialect): schema: bool = False, joins: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None, + parse_bracket: bool = False, ) -> t.Optional[exp.Expression]: - this = super()._parse_table(schema=schema, joins=joins, alias_tokens=alias_tokens) + this = super()._parse_table( + schema=schema, joins=joins, alias_tokens=alias_tokens, parse_bracket=parse_bracket + ) if self._match(TokenType.FINAL): this = self.expression(exp.Final, this=this) @@ -204,8 +210,10 @@ class ClickHouse(Dialect): self._match_set(self.JOIN_KINDS) and self._prev, ) - def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Join]: - join = super()._parse_join(skip_join_token) + def _parse_join( + self, skip_join_token: bool = False, parse_bracket: bool = False + ) -> t.Optional[exp.Join]: + join = super()._parse_join(skip_join_token=skip_join_token, parse_bracket=True) if join: join.set("global", join.args.pop("method", None)) @@ -318,6 +326,7 @@ class ClickHouse(Dialect): exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})", exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})", exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)), + exp.Xor: lambda self, e: self.func("xor", e.this, e.expression, *e.expressions), } PROPERTIES_LOCATION = { diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 5376dff..8c84639 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -12,6 +12,8 @@ from sqlglot.time import format_time from sqlglot.tokens import Token, Tokenizer, TokenType from sqlglot.trie import new_trie +B = t.TypeVar("B", bound=exp.Binary) + class Dialects(str, Enum): DIALECT = "" @@ -630,6 +632,16 @@ def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: ) +def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: + bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) + if bad_args: + self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}") + + return self.func( + "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] + ) + + def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: names = [] for agg in aggregations: @@ -650,3 +662,7 @@ def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectTyp names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) return names + + +def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: + return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 1d8a7fb..219b1aa 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import ( approx_count_distinct_sql, arrow_json_extract_scalar_sql, arrow_json_extract_sql, + binary_from_function, date_trunc_to_time, datestrtodate_sql, format_time_lambda, @@ -16,6 +17,7 @@ from sqlglot.dialects.dialect import ( no_safe_divide_sql, pivot_column_names, regexp_extract_sql, + regexp_replace_sql, rename_func, str_position_sql, str_to_time_sql, @@ -103,7 +105,6 @@ class DuckDB(Dialect): class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, - "~": TokenType.RLIKE, ":=": TokenType.EQ, "//": TokenType.DIV, "ATTACH": TokenType.COMMAND, @@ -128,6 +129,11 @@ class DuckDB(Dialect): class Parser(parser.Parser): CONCAT_NULL_OUTPUTS_STRING = True + BITWISE = { + **parser.Parser.BITWISE, + TokenType.TILDA: exp.RegexpLike, + } + FUNCTIONS = { **parser.Parser.FUNCTIONS, "ARRAY_LENGTH": exp.ArraySize.from_arg_list, @@ -158,6 +164,7 @@ class DuckDB(Dialect): "STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list, "TO_TIMESTAMP": exp.UnixToTime.from_arg_list, "UNNEST": exp.Explode.from_arg_list, + "XOR": binary_from_function(exp.BitwiseXor), } TYPE_TOKENS = { @@ -190,6 +197,7 @@ class DuckDB(Dialect): exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.ArraySort: _array_sort_sql, exp.ArraySum: rename_func("LIST_SUM"), + exp.BitwiseXor: lambda self, e: self.func("XOR", e.this, e.expression), exp.CommentColumnConstraint: no_comment_column_constraint_sql, exp.CurrentDate: lambda self, e: "CURRENT_DATE", exp.CurrentTime: lambda self, e: "CURRENT_TIME", @@ -203,7 +211,7 @@ class DuckDB(Dialect): exp.DateFromParts: rename_func("MAKE_DATE"), exp.DateSub: _date_delta_sql, exp.DateDiff: lambda self, e: self.func( - "DATE_DIFF", f"'{e.args.get('unit', 'day')}'", e.expression, e.this + "DATE_DIFF", f"'{e.args.get('unit') or 'day'}'", e.expression, e.this ), exp.DateStrToDate: datestrtodate_sql, exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)", @@ -217,8 +225,15 @@ class DuckDB(Dialect): exp.JSONBExtractScalar: arrow_json_extract_scalar_sql, exp.LogicalOr: rename_func("BOOL_OR"), exp.LogicalAnd: rename_func("BOOL_AND"), + exp.MonthsBetween: lambda self, e: self.func( + "DATEDIFF", + "'month'", + exp.cast(e.expression, "timestamp"), + exp.cast(e.this, "timestamp"), + ), exp.Properties: no_properties_sql, exp.RegexpExtract: regexp_extract_sql, + exp.RegexpReplace: regexp_replace_sql, exp.RegexpLike: rename_func("REGEXP_MATCHES"), exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"), exp.SafeDivide: no_safe_divide_sql, diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index f968f6a..e131434 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -18,6 +18,7 @@ from sqlglot.dialects.dialect import ( no_safe_divide_sql, no_trycast_sql, regexp_extract_sql, + regexp_replace_sql, rename_func, right_to_substring_sql, strposition_to_locate_sql, @@ -211,6 +212,7 @@ class Hive(Dialect): "ADD JAR": TokenType.COMMAND, "ADD JARS": TokenType.COMMAND, "MSCK REPAIR": TokenType.COMMAND, + "REFRESH": TokenType.COMMAND, "WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES, } @@ -270,6 +272,11 @@ class Hive(Dialect): "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)), } + FUNCTION_PARSERS = { + **parser.Parser.FUNCTION_PARSERS, + "TRANSFORM": lambda self: self._parse_transform(), + } + PROPERTY_PARSERS = { **parser.Parser.PROPERTY_PARSERS, "WITH SERDEPROPERTIES": lambda self: exp.SerdeProperties( @@ -277,6 +284,40 @@ class Hive(Dialect): ), } + def _parse_transform(self) -> exp.Transform | exp.QueryTransform: + args = self._parse_csv(self._parse_lambda) + self._match_r_paren() + + row_format_before = self._parse_row_format(match_row=True) + + record_writer = None + if self._match_text_seq("RECORDWRITER"): + record_writer = self._parse_string() + + if not self._match(TokenType.USING): + return exp.Transform.from_arg_list(args) + + command_script = self._parse_string() + + self._match(TokenType.ALIAS) + schema = self._parse_schema() + + row_format_after = self._parse_row_format(match_row=True) + record_reader = None + if self._match_text_seq("RECORDREADER"): + record_reader = self._parse_string() + + return self.expression( + exp.QueryTransform, + expressions=args, + command_script=command_script, + schema=schema, + row_format_before=row_format_before, + record_writer=record_writer, + row_format_after=row_format_after, + record_reader=record_reader, + ) + def _parse_types( self, check_func: bool = False, schema: bool = False ) -> t.Optional[exp.Expression]: @@ -363,11 +404,13 @@ class Hive(Dialect): exp.Max: max_or_greatest, exp.MD5Digest: lambda self, e: self.func("UNHEX", self.func("MD5", e.this)), exp.Min: min_or_least, + exp.MonthsBetween: lambda self, e: self.func("MONTHS_BETWEEN", e.this, e.expression), exp.VarMap: var_map_sql, exp.Create: create_with_partitions_sql, exp.Quantile: rename_func("PERCENTILE"), exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"), exp.RegexpExtract: regexp_extract_sql, + exp.RegexpReplace: regexp_replace_sql, exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"), exp.RegexpSplit: rename_func("SPLIT"), exp.Right: right_to_substring_sql, @@ -396,7 +439,6 @@ class Hive(Dialect): exp.UnixToTime: rename_func("FROM_UNIXTIME"), exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"), exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}", - exp.RowFormatSerdeProperty: lambda self, e: f"ROW FORMAT SERDE {self.sql(e, 'this')}", exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"), exp.NumberToStr: rename_func("FORMAT_NUMBER"), exp.LastDateOfMonth: rename_func("LAST_DAY"), @@ -410,6 +452,11 @@ class Hive(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def rowformatserdeproperty_sql(self, expression: exp.RowFormatSerdeProperty) -> str: + serde_props = self.sql(expression, "serde_properties") + serde_props = f" {serde_props}" if serde_props else "" + return f"ROW FORMAT SERDE {self.sql(expression, 'this')}{serde_props}" + def arrayagg_sql(self, expression: exp.ArrayAgg) -> str: return self.func( "COLLECT_LIST", diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index e4de934..5d65f77 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -427,6 +427,7 @@ class MySQL(Dialect): TABLE_HINTS = True DUPLICATE_KEY_UPDATE_WITH_SET = False QUERY_HINT_SEP = " " + VALUES_AS_TABLE = False TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -473,19 +474,32 @@ class MySQL(Dialect): LIMIT_FETCH = "LIMIT" + # MySQL doesn't support many datatypes in cast. + # https://dev.mysql.com/doc/refman/8.0/en/cast-functions.html#function_cast + CAST_MAPPING = { + exp.DataType.Type.BIGINT: "SIGNED", + exp.DataType.Type.BOOLEAN: "SIGNED", + exp.DataType.Type.INT: "SIGNED", + exp.DataType.Type.TEXT: "CHAR", + exp.DataType.Type.UBIGINT: "UNSIGNED", + exp.DataType.Type.VARCHAR: "CHAR", + } + + def xor_sql(self, expression: exp.Xor) -> str: + if expression.expressions: + return self.expressions(expression, sep=" XOR ") + return super().xor_sql(expression) + def jsonarraycontains_sql(self, expression: exp.JSONArrayContains) -> str: return f"{self.sql(expression, 'this')} MEMBER OF({self.sql(expression, 'expression')})" def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: - """(U)BIGINT is not allowed in a CAST expression, so we use (UN)SIGNED instead.""" - if expression.to.this == exp.DataType.Type.BIGINT: - to = "SIGNED" - elif expression.to.this == exp.DataType.Type.UBIGINT: - to = "UNSIGNED" - else: - return super().cast_sql(expression) + to = self.CAST_MAPPING.get(expression.to.this) - return f"CAST({self.sql(expression, 'this')} AS {to})" + if to: + expression = expression.copy() + expression.to.set("this", to) + return super().cast_sql(expression) def show_sql(self, expression: exp.Show) -> str: this = f" {expression.name}" diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 7706456..d11cbd7 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -282,7 +282,6 @@ class Postgres(Dialect): VAR_SINGLE_TOKENS = {"$"} class Parser(parser.Parser): - STRICT_CAST = False CONCAT_NULL_OUTPUTS_STRING = True FUNCTIONS = { @@ -318,6 +317,11 @@ class Postgres(Dialect): TokenType.LT_AT: binary_range_parser(exp.ArrayContained), } + STATEMENT_PARSERS = { + **parser.Parser.STATEMENT_PARSERS, + TokenType.END: lambda self: self._parse_commit_or_rollback(), + } + def _parse_factor(self) -> t.Optional[exp.Expression]: return self._parse_tokens(self._parse_exponent, self.FACTOR) diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 7d35c67..265c6e5 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -5,6 +5,7 @@ import typing as t from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, + binary_from_function, date_trunc_to_time, format_time_lambda, if_sql, @@ -198,6 +199,10 @@ class Presto(Dialect): **parser.Parser.FUNCTIONS, "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list, "APPROX_PERCENTILE": _approx_percentile, + "BITWISE_AND": binary_from_function(exp.BitwiseAnd), + "BITWISE_NOT": lambda args: exp.BitwiseNot(this=seq_get(args, 0)), + "BITWISE_OR": binary_from_function(exp.BitwiseOr), + "BITWISE_XOR": binary_from_function(exp.BitwiseXor), "CARDINALITY": exp.ArraySize.from_arg_list, "CONTAINS": exp.ArrayContains.from_arg_list, "DATE_ADD": lambda args: exp.DateAdd( diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 09edd55..f687ba7 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -27,6 +27,11 @@ class Redshift(Postgres): class Parser(Postgres.Parser): FUNCTIONS = { **Postgres.Parser.FUNCTIONS, + "ADD_MONTHS": lambda args: exp.DateAdd( + this=exp.TsOrDsToDate(this=seq_get(args, 0)), + expression=seq_get(args, 1), + unit=exp.var("month"), + ), "DATEADD": lambda args: exp.DateAdd( this=exp.TsOrDsToDate(this=seq_get(args, 2)), expression=seq_get(args, 1), @@ -37,7 +42,6 @@ class Redshift(Postgres): expression=exp.TsOrDsToDate(this=seq_get(args, 1)), unit=seq_get(args, 0), ), - "NVL": exp.Coalesce.from_arg_list, "STRTOL": exp.FromBase.from_arg_list, } @@ -87,6 +91,7 @@ class Redshift(Postgres): LOCKING_READS_SUPPORTED = False RENAME_TABLE_WITH_DB = False QUERY_HINTS = False + VALUES_AS_TABLE = False TYPE_MAPPING = { **Postgres.Generator.TYPE_MAPPING, @@ -129,40 +134,6 @@ class Redshift(Postgres): RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot", "type"} - def values_sql(self, expression: exp.Values) -> str: - """ - Converts `VALUES...` expression into a series of unions. - - Note: If you have a lot of unions then this will result in a large number of recursive statements to - evaluate the expression. You may need to increase `sys.setrecursionlimit` to run and it can also be - very slow. - """ - - # The VALUES clause is still valid in an `INSERT INTO ..` statement, for example - if not expression.find_ancestor(exp.From, exp.Join): - return super().values_sql(expression) - - column_names = expression.alias and expression.args["alias"].columns - - selects = [] - rows = [tuple_exp.expressions for tuple_exp in expression.expressions] - - for i, row in enumerate(rows): - if i == 0 and column_names: - row = [ - exp.alias_(value, column_name) - for value, column_name in zip(row, column_names) - ] - - selects.append(exp.Select(expressions=row)) - - subquery_expression: exp.Select | exp.Union = selects[0] - if len(selects) > 1: - for select in selects[1:]: - subquery_expression = exp.union(subquery_expression, select, distinct=False) - - return self.subquery_sql(subquery_expression.subquery(expression.alias)) - def with_properties(self, properties: exp.Properties) -> str: """Redshift doesn't have `WITH` as part of their with_properties so we remove it""" return self.properties(properties, prefix=" ", suffix="") diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 715a84c..499e085 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -30,7 +30,7 @@ def _check_int(s: str) -> bool: # from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html -def _snowflake_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]: +def _parse_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]: if len(args) == 2: first_arg, second_arg = args if second_arg.is_string: @@ -137,7 +137,7 @@ def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]: # https://docs.snowflake.com/en/sql-reference/functions/div0 -def _div0_to_if(args: t.List) -> exp.Expression: +def _div0_to_if(args: t.List) -> exp.If: cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0)) true = exp.Literal.number(0) false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1)) @@ -145,13 +145,13 @@ def _div0_to_if(args: t.List) -> exp.Expression: # https://docs.snowflake.com/en/sql-reference/functions/zeroifnull -def _zeroifnull_to_if(args: t.List) -> exp.Expression: +def _zeroifnull_to_if(args: t.List) -> exp.If: cond = exp.Is(this=seq_get(args, 0), expression=exp.Null()) return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0)) # https://docs.snowflake.com/en/sql-reference/functions/zeroifnull -def _nullifzero_to_if(args: t.List) -> exp.Expression: +def _nullifzero_to_if(args: t.List) -> exp.If: cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0)) return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0)) @@ -164,12 +164,21 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: return self.datatype_sql(expression) -def _parse_convert_timezone(args: t.List) -> exp.Expression: +def _parse_convert_timezone(args: t.List) -> t.Union[exp.Anonymous, exp.AtTimeZone]: if len(args) == 3: return exp.Anonymous(this="CONVERT_TIMEZONE", expressions=args) return exp.AtTimeZone(this=seq_get(args, 1), zone=seq_get(args, 0)) +def _parse_regexp_replace(args: t.List) -> exp.RegexpReplace: + regexp_replace = exp.RegexpReplace.from_arg_list(args) + + if not regexp_replace.args.get("replacement"): + regexp_replace.set("replacement", exp.Literal.string("")) + + return regexp_replace + + class Snowflake(Dialect): # https://docs.snowflake.com/en/sql-reference/identifiers-syntax RESOLVES_IDENTIFIERS_AS_UPPERCASE = True @@ -223,13 +232,14 @@ class Snowflake(Dialect): "IFF": exp.If.from_arg_list, "NULLIFZERO": _nullifzero_to_if, "OBJECT_CONSTRUCT": _parse_object_construct, + "REGEXP_REPLACE": _parse_regexp_replace, "REGEXP_SUBSTR": exp.RegexpExtract.from_arg_list, "RLIKE": exp.RegexpLike.from_arg_list, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), "TIMEDIFF": _parse_datediff, "TIMESTAMPDIFF": _parse_datediff, "TO_ARRAY": exp.Array.from_arg_list, - "TO_TIMESTAMP": _snowflake_to_timestamp, + "TO_TIMESTAMP": _parse_to_timestamp, "TO_VARCHAR": exp.ToChar.from_arg_list, "ZEROIFNULL": _zeroifnull_to_if, } @@ -242,7 +252,6 @@ class Snowflake(Dialect): FUNC_TOKENS = { *parser.Parser.FUNC_TOKENS, - TokenType.RLIKE, TokenType.TABLE, } diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index f909e8c..dcaa524 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -2,9 +2,11 @@ from __future__ import annotations import typing as t -from sqlglot import exp, parser, transforms +from sqlglot import exp, transforms from sqlglot.dialects.dialect import ( + binary_from_function, create_with_partitions_sql, + format_time_lambda, pivot_column_names, rename_func, trim_sql, @@ -108,47 +110,36 @@ class Spark2(Hive): class Parser(Hive.Parser): FUNCTIONS = { **Hive.Parser.FUNCTIONS, - "MAP_FROM_ARRAYS": exp.Map.from_arg_list, - "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list, - "SHIFTLEFT": lambda args: exp.BitwiseLeftShift( - this=seq_get(args, 0), - expression=seq_get(args, 1), - ), - "SHIFTRIGHT": lambda args: exp.BitwiseRightShift( - this=seq_get(args, 0), - expression=seq_get(args, 1), - ), - "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, - "IIF": exp.If.from_arg_list, "AGGREGATE": exp.Reduce.from_arg_list, - "DAYOFWEEK": lambda args: exp.DayOfWeek( - this=exp.TsOrDsToDate(this=seq_get(args, 0)), - ), - "DAYOFMONTH": lambda args: exp.DayOfMonth( - this=exp.TsOrDsToDate(this=seq_get(args, 0)), - ), - "DAYOFYEAR": lambda args: exp.DayOfYear( - this=exp.TsOrDsToDate(this=seq_get(args, 0)), - ), - "WEEKOFYEAR": lambda args: exp.WeekOfYear( - this=exp.TsOrDsToDate(this=seq_get(args, 0)), - ), - "DATE_TRUNC": lambda args: exp.TimestampTrunc( - this=seq_get(args, 1), - unit=exp.var(seq_get(args, 0)), - ), - "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)), + "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, "BOOLEAN": _parse_as_cast("boolean"), "DATE": _parse_as_cast("date"), + "DATE_TRUNC": lambda args: exp.TimestampTrunc( + this=seq_get(args, 1), unit=exp.var(seq_get(args, 0)) + ), + "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))), + "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))), + "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "DOUBLE": _parse_as_cast("double"), "FLOAT": _parse_as_cast("float"), + "IIF": exp.If.from_arg_list, "INT": _parse_as_cast("int"), + "MAP_FROM_ARRAYS": exp.Map.from_arg_list, + "RLIKE": exp.RegexpLike.from_arg_list, + "SHIFTLEFT": binary_from_function(exp.BitwiseLeftShift), + "SHIFTRIGHT": binary_from_function(exp.BitwiseRightShift), "STRING": _parse_as_cast("string"), "TIMESTAMP": _parse_as_cast("timestamp"), + "TO_TIMESTAMP": lambda args: _parse_as_cast("timestamp")(args) + if len(args) == 1 + else format_time_lambda(exp.StrToTime, "spark")(args), + "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list, + "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)), + "WEEKOFYEAR": lambda args: exp.WeekOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))), } FUNCTION_PARSERS = { - **parser.Parser.FUNCTION_PARSERS, + **Hive.Parser.FUNCTION_PARSERS, "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"), "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"), "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"), @@ -207,6 +198,13 @@ class Spark2(Hive): exp.Map: _map_sql, exp.Pivot: transforms.preprocess([_unqualify_pivot_columns]), exp.Reduce: rename_func("AGGREGATE"), + exp.RegexpReplace: lambda self, e: self.func( + "REGEXP_REPLACE", + e.this, + e.expression, + e.args["replacement"], + e.args.get("position"), + ), exp.StrToDate: _str_to_date, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimestampTrunc: lambda self, e: self.func( @@ -224,6 +222,7 @@ class Spark2(Hive): TRANSFORMS.pop(exp.ArraySort) TRANSFORMS.pop(exp.ILike) TRANSFORMS.pop(exp.Left) + TRANSFORMS.pop(exp.MonthsBetween) TRANSFORMS.pop(exp.Right) WRAP_DERIVED_VALUES = False diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py index 0390113..baa62e8 100644 --- a/sqlglot/dialects/starrocks.py +++ b/sqlglot/dialects/starrocks.py @@ -20,6 +20,8 @@ class StarRocks(MySQL): } class Generator(MySQL.Generator): + CAST_MAPPING = {} + TYPE_MAPPING = { **MySQL.Generator.TYPE_MAPPING, exp.DataType.Type.TEXT: "STRING", diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index b77c2c0..01d5001 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -138,7 +138,8 @@ def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.Tim if isinstance(expression, exp.NumberToStr) else exp.Literal.string( format_time( - expression.text("format"), t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING) + expression.text("format"), + t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING), ) ) ) @@ -314,7 +315,9 @@ class TSQL(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, "CHARINDEX": lambda args: exp.StrPosition( - this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2) + this=seq_get(args, 1), + substr=seq_get(args, 0), + position=seq_get(args, 2), ), "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), @@ -365,6 +368,55 @@ class TSQL(Dialect): CONCAT_NULL_OUTPUTS_STRING = True + def _parse_commit_or_rollback(self) -> exp.Commit | exp.Rollback: + """Applies to SQL Server and Azure SQL Database + COMMIT [ { TRAN | TRANSACTION } + [ transaction_name | @tran_name_variable ] ] + [ WITH ( DELAYED_DURABILITY = { OFF | ON } ) ] + + ROLLBACK { TRAN | TRANSACTION } + [ transaction_name | @tran_name_variable + | savepoint_name | @savepoint_variable ] + """ + rollback = self._prev.token_type == TokenType.ROLLBACK + + self._match_texts({"TRAN", "TRANSACTION"}) + this = self._parse_id_var() + + if rollback: + return self.expression(exp.Rollback, this=this) + + durability = None + if self._match_pair(TokenType.WITH, TokenType.L_PAREN): + self._match_text_seq("DELAYED_DURABILITY") + self._match(TokenType.EQ) + + if self._match_text_seq("OFF"): + durability = False + else: + self._match(TokenType.ON) + durability = True + + self._match_r_paren() + + return self.expression(exp.Commit, this=this, durability=durability) + + def _parse_transaction(self) -> exp.Transaction | exp.Command: + """Applies to SQL Server and Azure SQL Database + BEGIN { TRAN | TRANSACTION } + [ { transaction_name | @tran_name_variable } + [ WITH MARK [ 'description' ] ] + ] + """ + if self._match_texts(("TRAN", "TRANSACTION")): + transaction = self.expression(exp.Transaction, this=self._parse_id_var()) + if self._match_text_seq("WITH", "MARK"): + transaction.set("mark", self._parse_string()) + + return transaction + + return self._parse_as_command(self._prev) + def _parse_system_time(self) -> t.Optional[exp.Expression]: if not self._match_text_seq("FOR", "SYSTEM_TIME"): return None @@ -496,7 +548,9 @@ class TSQL(Dialect): exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this), exp.SHA2: lambda self, e: self.func( - "HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this + "HASHBYTES", + exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), + e.this, ), exp.TimeToStr: _format_sql, } @@ -539,3 +593,26 @@ class TSQL(Dialect): into = self.sql(expression, "into") into = self.seg(f"INTO {into}") if into else "" return f"{self.seg('OUTPUT')} {self.expressions(expression, flat=True)}{into}" + + def transaction_sql(self, expression: exp.Transaction) -> str: + this = self.sql(expression, "this") + this = f" {this}" if this else "" + mark = self.sql(expression, "mark") + mark = f" WITH MARK {mark}" if mark else "" + return f"BEGIN TRANSACTION{this}{mark}" + + def commit_sql(self, expression: exp.Commit) -> str: + this = self.sql(expression, "this") + this = f" {this}" if this else "" + durability = expression.args.get("durability") + durability = ( + f" WITH (DELAYED_DURABILITY = {'ON' if durability else 'OFF'})" + if durability is not None + else "" + ) + return f"COMMIT TRANSACTION{this}{durability}" + + def rollback_sql(self, expression: exp.Rollback) -> str: + this = self.sql(expression, "this") + this = f" {this}" if this else "" + return f"ROLLBACK TRANSACTION{this}" diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 264b8e9..9a6b440 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -759,12 +759,24 @@ class Condition(Expression): ) def isin( - self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy: bool = True, **opts + self, + *expressions: t.Any, + query: t.Optional[ExpOrStr] = None, + unnest: t.Optional[ExpOrStr] | t.Collection[ExpOrStr] = None, + copy: bool = True, + **opts, ) -> In: return In( this=_maybe_copy(self, copy), expressions=[convert(e, copy=copy) for e in expressions], query=maybe_parse(query, copy=copy, **opts) if query else None, + unnest=Unnest( + expressions=[ + maybe_parse(t.cast(ExpOrStr, e), copy=copy, **opts) for e in ensure_list(unnest) + ] + ) + if unnest + else None, ) def between(self, low: t.Any, high: t.Any, copy: bool = True, **opts) -> Between: @@ -2019,7 +2031,20 @@ class RowFormatDelimitedProperty(Property): class RowFormatSerdeProperty(Property): - arg_types = {"this": True} + arg_types = {"this": True, "serde_properties": False} + + +# https://spark.apache.org/docs/3.1.2/sql-ref-syntax-qry-select-transform.html +class QueryTransform(Expression): + arg_types = { + "expressions": True, + "command_script": True, + "schema": False, + "row_format_before": False, + "record_writer": False, + "row_format_after": False, + "record_reader": False, + } class SchemaCommentProperty(Property): @@ -2149,12 +2174,24 @@ class Tuple(Expression): arg_types = {"expressions": False} def isin( - self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy: bool = True, **opts + self, + *expressions: t.Any, + query: t.Optional[ExpOrStr] = None, + unnest: t.Optional[ExpOrStr] | t.Collection[ExpOrStr] = None, + copy: bool = True, + **opts, ) -> In: return In( this=_maybe_copy(self, copy), expressions=[convert(e, copy=copy) for e in expressions], query=maybe_parse(query, copy=copy, **opts) if query else None, + unnest=Unnest( + expressions=[ + maybe_parse(t.cast(ExpOrStr, e), copy=copy, **opts) for e in ensure_list(unnest) + ] + ) + if unnest + else None, ) @@ -3478,15 +3515,15 @@ class Command(Expression): class Transaction(Expression): - arg_types = {"this": False, "modes": False} + arg_types = {"this": False, "modes": False, "mark": False} class Commit(Expression): - arg_types = {"chain": False} + arg_types = {"chain": False, "this": False, "durability": False} class Rollback(Expression): - arg_types = {"savepoint": False} + arg_types = {"savepoint": False, "this": False} class AlterTable(Expression): @@ -3530,10 +3567,6 @@ class Or(Connector): pass -class Xor(Connector): - pass - - class BitwiseAnd(Binary): pass @@ -3856,6 +3889,11 @@ class Abs(Func): pass +# https://spark.apache.org/docs/latest/api/sql/index.html#transform +class Transform(Func): + arg_types = {"this": True, "expression": True} + + class Anonymous(Func): arg_types = {"this": True, "expressions": False} is_var_len_args = True @@ -4098,6 +4136,10 @@ class WeekOfYear(Func): _sql_names = ["WEEK_OF_YEAR", "WEEKOFYEAR"] +class MonthsBetween(Func): + arg_types = {"this": True, "expression": True, "roundoff": False} + + class LastDateOfMonth(Func): pass @@ -4209,6 +4251,10 @@ class Hex(Func): pass +class Xor(Connector, Func): + arg_types = {"this": False, "expression": False, "expressions": False} + + class If(Func): arg_types = {"this": True, "true": True, "false": False} @@ -4431,7 +4477,18 @@ class RegexpExtract(Func): } -class RegexpLike(Func): +class RegexpReplace(Func): + arg_types = { + "this": True, + "expression": True, + "replacement": True, + "position": False, + "occurrence": False, + "parameters": False, + } + + +class RegexpLike(Binary, Func): arg_types = {"this": True, "expression": True, "flag": False} diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 857eff1..40ba88e 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -164,6 +164,11 @@ class Generator: # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE") + # Whether or not VALUES statements can be used as derived tables. + # MySQL 5 and Redshift do not allow this, so when False, it will convert + # SELECT * VALUES into SELECT UNION + VALUES_AS_TABLE = True + TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", exp.DataType.Type.NVARCHAR: "VARCHAR", @@ -260,8 +265,9 @@ class Generator: # Expressions whose comments are separated from them for better formatting WITH_SEPARATED_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = ( - exp.Select, + exp.Drop, exp.From, + exp.Select, exp.Where, exp.With, ) @@ -818,7 +824,11 @@ class Generator: def datatype_sql(self, expression: exp.DataType) -> str: type_value = expression.this - type_sql = self.TYPE_MAPPING.get(type_value, type_value.value) + type_sql = ( + self.TYPE_MAPPING.get(type_value, type_value.value) + if isinstance(type_value, exp.DataType.Type) + else type_value + ) nested = "" interior = self.expressions(expression, flat=True) values = "" @@ -1307,15 +1317,45 @@ class Generator: return self.prepend_ctes(expression, sql) def values_sql(self, expression: exp.Values) -> str: - args = self.expressions(expression) - alias = self.sql(expression, "alias") - values = f"VALUES{self.seg('')}{args}" - values = ( - f"({values})" - if self.WRAP_DERIVED_VALUES and (alias or isinstance(expression.parent, exp.From)) - else values - ) - return f"{values} AS {alias}" if alias else values + # The VALUES clause is still valid in an `INSERT INTO ..` statement, for example + if self.VALUES_AS_TABLE or not expression.find_ancestor(exp.From, exp.Join): + args = self.expressions(expression) + alias = self.sql(expression, "alias") + values = f"VALUES{self.seg('')}{args}" + values = ( + f"({values})" + if self.WRAP_DERIVED_VALUES and (alias or isinstance(expression.parent, exp.From)) + else values + ) + return f"{values} AS {alias}" if alias else values + + # Converts `VALUES...` expression into a series of select unions. + # Note: If you have a lot of unions then this will result in a large number of recursive statements to + # evaluate the expression. You may need to increase `sys.setrecursionlimit` to run and it can also be + # very slow. + expression = expression.copy() + column_names = expression.alias and expression.args["alias"].columns + + selects = [] + + for i, tup in enumerate(expression.expressions): + row = tup.expressions + + if i == 0 and column_names: + row = [ + exp.alias_(value, column_name) for value, column_name in zip(row, column_names) + ] + + selects.append(exp.Select(expressions=row)) + + subquery_expression: exp.Select | exp.Union = selects[0] + if len(selects) > 1: + for select in selects[1:]: + subquery_expression = exp.union( + subquery_expression, select, distinct=False, copy=False + ) + + return self.subquery_sql(subquery_expression.subquery(expression.alias, copy=False)) def var_sql(self, expression: exp.Var) -> str: return self.sql(expression, "this") @@ -2043,7 +2083,7 @@ class Generator: def and_sql(self, expression: exp.And) -> str: return self.connector_sql(expression, "AND") - def xor_sql(self, expression: exp.And) -> str: + def xor_sql(self, expression: exp.Xor) -> str: return self.connector_sql(expression, "XOR") def connector_sql(self, expression: exp.Connector, op: str) -> str: @@ -2507,6 +2547,21 @@ class Generator: return self.func("ANY_VALUE", this) + def querytransform_sql(self, expression: exp.QueryTransform) -> str: + transform = self.func("TRANSFORM", *expression.expressions) + row_format_before = self.sql(expression, "row_format_before") + row_format_before = f" {row_format_before}" if row_format_before else "" + record_writer = self.sql(expression, "record_writer") + record_writer = f" RECORDWRITER {record_writer}" if record_writer else "" + using = f" USING {self.sql(expression, 'command_script')}" + schema = self.sql(expression, "schema") + schema = f" AS {schema}" if schema else "" + row_format_after = self.sql(expression, "row_format_after") + row_format_after = f" {row_format_after}" if row_format_after else "" + record_reader = self.sql(expression, "record_reader") + record_reader = f" RECORDREADER {record_reader}" if record_reader else "" + return f"{transform}{row_format_before}{record_writer}{using}{schema}{row_format_after}{record_reader}" + def cached_generator( cache: t.Optional[t.Dict[int, str]] = None diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py index 9f5ae9a..113458f 100644 --- a/sqlglot/lineage.py +++ b/sqlglot/lineage.py @@ -79,7 +79,7 @@ def lineage( raise SqlglotError("Cannot build lineage, sql must be SELECT") def to_node( - column_name: str, + column: str | int, scope: Scope, scope_name: t.Optional[str] = None, upstream: t.Optional[Node] = None, @@ -90,26 +90,38 @@ def lineage( for dt in scope.derived_tables if dt.comments and dt.comments[0].startswith("source: ") } - if isinstance(scope.expression, exp.Union): - for scope in scope.union_scopes: - node = to_node( - column_name, - scope=scope, - scope_name=scope_name, - upstream=upstream, - alias=aliases.get(scope_name), - ) - return node # Find the specific select clause that is the source of the column we want. # This can either be a specific, named select or a generic `*` clause. - select = next( - (select for select in scope.expression.selects if select.alias_or_name == column_name), - exp.Star() if scope.expression.is_star else None, + select = ( + scope.expression.selects[column] + if isinstance(column, int) + else next( + (select for select in scope.expression.selects if select.alias_or_name == column), + exp.Star() if scope.expression.is_star else None, + ) ) if not select: - raise ValueError(f"Could not find {column_name} in {scope.expression}") + raise ValueError(f"Could not find {column} in {scope.expression}") + + if isinstance(scope.expression, exp.Union): + upstream = upstream or Node(name="UNION", source=scope.expression, expression=select) + + index = ( + column + if isinstance(column, int) + else next( + i + for i, select in enumerate(scope.expression.selects) + if select.alias_or_name == column + ) + ) + + for s in scope.union_scopes: + to_node(index, scope=s, upstream=upstream) + + return upstream if isinstance(scope.expression, exp.Select): # For better ergonomics in our node labels, replace the full select with @@ -122,7 +134,7 @@ def lineage( # Create the node for this step in the lineage chain, and attach it to the previous one. node = Node( - name=f"{scope_name}.{column_name}" if scope_name else column_name, + name=f"{scope_name}.{column}" if scope_name else str(column), source=source, expression=select, alias=alias or "", diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index 5ae1fa0..728493d 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -144,8 +144,9 @@ def _eliminate_derived_table(scope, existing_ctes, taken): name, cte = _new_cte(scope, existing_ctes, taken) table = exp.alias_(exp.table_(name), alias=parent.alias or name) - parent.replace(table) + table.set("joins", parent.args.get("joins")) + parent.replace(table) return cte diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index 6ee057b..7322424 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -176,6 +176,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): return ( isinstance(outer_scope.expression, exp.Select) + and not outer_scope.expression.is_star and isinstance(inner_select, exp.Select) and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS) and inner_select.args.get("from") @@ -242,6 +243,7 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias): alias (str) """ new_subquery = inner_scope.expression.args["from"].this + new_subquery.set("joins", node_to_replace.args.get("joins")) node_to_replace.replace(new_subquery) for join_hint in outer_scope.join_hints: tables = join_hint.find_all(exp.Table) diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index 97e8ff6..c81fd00 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -61,6 +61,9 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True) if remove_unused_selections: _remove_unused_selections(scope, parent_selections, schema) + if scope.expression.is_star: + continue + # Group columns by source name selects = defaultdict(set) for col in scope.columns: diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 7972b2b..2657188 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -29,12 +29,13 @@ def qualify_columns( 'SELECT tbl.col AS col FROM tbl' Args: - expression: expression to qualify - schema: Database schema - expand_alias_refs: whether or not to expand references to aliases - infer_schema: whether or not to infer the schema if missing + expression: Expression to qualify. + schema: Database schema. + expand_alias_refs: Whether or not to expand references to aliases. + infer_schema: Whether or not to infer the schema if missing. + Returns: - sqlglot.Expression: qualified expression + The qualified expression. """ schema = ensure_schema(schema) infer_schema = schema.empty if infer_schema is None else infer_schema @@ -410,7 +411,9 @@ def _expand_stars( else: return - scope.expression.set("expressions", new_selections) + # Ensures we don't overwrite the initial selections with an empty list + if new_selections: + scope.expression.set("expressions", new_selections) def _add_except_columns( diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index b2b4230..a7dab35 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -124,8 +124,8 @@ class Scope: self._ctes.append(node) elif ( isinstance(node, exp.Subquery) - and isinstance(parent, (exp.From, exp.Join)) - and _is_subquery_scope(node) + and isinstance(parent, (exp.From, exp.Join, exp.Subquery)) + and _is_derived_table(node) ): self._derived_tables.append(node) elif isinstance(node, exp.Subqueryable): @@ -610,13 +610,13 @@ def _traverse_ctes(scope): scope.sources.update(sources) -def _is_subquery_scope(expression: exp.Subquery) -> bool: +def _is_derived_table(expression: exp.Subquery) -> bool: """ - We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a new scope. - If an alias is present, it shadows all names under the Subquery, so that's an - exception to this rule. + We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table", + as it doesn't introduce a new scope. If an alias is present, it shadows all names + under the Subquery, so that's one exception to this rule. """ - return bool(not isinstance(expression.unnest(), exp.Table) or expression.alias) + return bool(expression.alias or isinstance(expression.this, exp.Subqueryable)) def _traverse_tables(scope): @@ -654,7 +654,10 @@ def _traverse_tables(scope): else: sources[source_name] = expression - expressions.extend(join.this for join in expression.args.get("joins") or []) + # Make sure to not include the joins twice + if expression is not scope.expression: + expressions.extend(join.this for join in expression.args.get("joins") or []) + continue if not isinstance(expression, exp.DerivedTable): @@ -664,10 +667,11 @@ def _traverse_tables(scope): lateral_sources = sources scope_type = ScopeType.UDTF scopes = scope.udtf_scopes - elif _is_subquery_scope(expression): + elif _is_derived_table(expression): lateral_sources = None scope_type = ScopeType.DERIVED_TABLE scopes = scope.derived_table_scopes + expressions.extend(join.this for join in expression.args.get("joins") or []) else: # Makes sure we check for possible sources in nested table constructs expressions.append(expression.this) @@ -735,10 +739,16 @@ def walk_in_scope(expression, bfs=True): isinstance(node, exp.CTE) or ( isinstance(node, exp.Subquery) - and isinstance(parent, (exp.From, exp.Join)) - and _is_subquery_scope(node) + and isinstance(parent, (exp.From, exp.Join, exp.Subquery)) + and _is_derived_table(node) ) or isinstance(node, exp.UDTF) or isinstance(node, exp.Subqueryable) ): prune = True + + if isinstance(node, (exp.Subquery, exp.UDTF)): + # The following args are not actually in the inner scope, so we should visit them + for key in ("joins", "laterals", "pivots"): + for arg in node.args.get(key) or []: + yield from walk_in_scope(arg, bfs=bfs) diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 508a273..5adec77 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -327,6 +327,7 @@ class Parser(metaclass=_Parser): TokenType.PRIMARY_KEY, TokenType.RANGE, TokenType.REPLACE, + TokenType.RLIKE, TokenType.ROW, TokenType.UNNEST, TokenType.VAR, @@ -338,6 +339,7 @@ class Parser(metaclass=_Parser): TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, TokenType.WINDOW, + TokenType.XOR, *TYPE_TOKENS, *SUBQUERY_PREDICATES, } @@ -505,7 +507,6 @@ class Parser(metaclass=_Parser): TokenType.DESC: lambda self: self._parse_describe(), TokenType.DESCRIBE: lambda self: self._parse_describe(), TokenType.DROP: lambda self: self._parse_drop(), - TokenType.END: lambda self: self._parse_commit_or_rollback(), TokenType.FROM: lambda self: exp.select("*").from_( t.cast(exp.From, self._parse_from(skip_from_token=True)) ), @@ -716,7 +717,7 @@ class Parser(metaclass=_Parser): FUNCTIONS_WITH_ALIASED_ARGS = {"STRUCT"} - FUNCTION_PARSERS: t.Dict[str, t.Callable] = { + FUNCTION_PARSERS = { "ANY_VALUE": lambda self: self._parse_any_value(), "CAST": lambda self: self._parse_cast(self.STRICT_CAST), "CONCAT": lambda self: self._parse_concat(), @@ -1144,6 +1145,7 @@ class Parser(metaclass=_Parser): return self.expression( exp.Drop, + comments=start.comments, exists=self._parse_exists(), this=self._parse_table(schema=True), kind=kind, @@ -1233,11 +1235,14 @@ class Parser(metaclass=_Parser): expression = self._parse_ddl_select() if create_token.token_type == TokenType.TABLE: + # exp.Properties.Location.POST_EXPRESSION + extend_props(self._parse_properties()) + indexes = [] while True: index = self._parse_index() - # exp.Properties.Location.POST_EXPRESSION and POST_INDEX + # exp.Properties.Location.POST_INDEX extend_props(self._parse_properties()) if not index: @@ -1384,7 +1389,6 @@ class Parser(metaclass=_Parser): def _parse_with_property( self, ) -> t.Optional[exp.Expression] | t.List[t.Optional[exp.Expression]]: - self._match(TokenType.WITH) if self._match(TokenType.L_PAREN, advance=False): return self._parse_wrapped_csv(self._parse_property) @@ -1781,7 +1785,17 @@ class Parser(metaclass=_Parser): return None if self._match_text_seq("SERDE"): - return self.expression(exp.RowFormatSerdeProperty, this=self._parse_string()) + this = self._parse_string() + + serde_properties = None + if self._match(TokenType.SERDE_PROPERTIES): + serde_properties = self.expression( + exp.SerdeProperties, expressions=self._parse_wrapped_csv(self._parse_property) + ) + + return self.expression( + exp.RowFormatSerdeProperty, this=this, serde_properties=serde_properties + ) self._match_text_seq("DELIMITED") @@ -2251,7 +2265,9 @@ class Parser(metaclass=_Parser): self._match_set(self.JOIN_KINDS) and self._prev, ) - def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Join]: + def _parse_join( + self, skip_join_token: bool = False, parse_bracket: bool = False + ) -> t.Optional[exp.Join]: if self._match(TokenType.COMMA): return self.expression(exp.Join, this=self._parse_table()) @@ -2275,7 +2291,7 @@ class Parser(metaclass=_Parser): if outer_apply: side = Token(TokenType.LEFT, "LEFT") - kwargs: t.Dict[str, t.Any] = {"this": self._parse_table()} + kwargs: t.Dict[str, t.Any] = {"this": self._parse_table(parse_bracket=parse_bracket)} if method: kwargs["method"] = method.text @@ -2411,6 +2427,7 @@ class Parser(metaclass=_Parser): schema: bool = False, joins: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None, + parse_bracket: bool = False, ) -> t.Optional[exp.Expression]: lateral = self._parse_lateral() if lateral: @@ -2430,7 +2447,9 @@ class Parser(metaclass=_Parser): subquery.set("pivots", self._parse_pivots()) return subquery - this: exp.Expression = self._parse_table_parts(schema=schema) + bracket = parse_bracket and self._parse_bracket(None) + bracket = self.expression(exp.Table, this=bracket) if bracket else None + this: exp.Expression = bracket or self._parse_table_parts(schema=schema) if schema: return self._parse_schema(this=this) @@ -2758,8 +2777,15 @@ class Parser(metaclass=_Parser): self, this: t.Optional[exp.Expression] = None, top: bool = False ) -> t.Optional[exp.Expression]: if self._match(TokenType.TOP if top else TokenType.LIMIT): - limit_paren = self._match(TokenType.L_PAREN) - expression = self._parse_number() if top else self._parse_term() + comments = self._prev_comments + if top: + limit_paren = self._match(TokenType.L_PAREN) + expression = self._parse_number() + + if limit_paren: + self._match_r_paren() + else: + expression = self._parse_term() if self._match(TokenType.COMMA): offset = expression @@ -2767,10 +2793,9 @@ class Parser(metaclass=_Parser): else: offset = None - limit_exp = self.expression(exp.Limit, this=this, expression=expression, offset=offset) - - if limit_paren: - self._match_r_paren() + limit_exp = self.expression( + exp.Limit, this=this, expression=expression, offset=offset, comments=comments + ) return limit_exp @@ -2803,7 +2828,7 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.OFFSET): return this - count = self._parse_number() + count = self._parse_term() self._match_set((TokenType.ROW, TokenType.ROWS)) return self.expression(exp.Offset, this=this, expression=count) @@ -3320,7 +3345,7 @@ class Parser(metaclass=_Parser): else: this = self.expression(exp.Anonymous, this=this, expressions=args) - self._match_r_paren(this) + self._match(TokenType.R_PAREN, expression=this) return self._parse_window(this) def _parse_function_parameter(self) -> t.Optional[exp.Expression]: @@ -4076,7 +4101,10 @@ class Parser(metaclass=_Parser): self, this: t.Optional[exp.Expression], alias: bool = False ) -> t.Optional[exp.Expression]: if self._match_pair(TokenType.FILTER, TokenType.L_PAREN): - this = self.expression(exp.Filter, this=this, expression=self._parse_where()) + self._match(TokenType.WHERE) + this = self.expression( + exp.Filter, this=this, expression=self._parse_where(skip_where_token=True) + ) self._match_r_paren() # T-SQL allows the OVER (...) syntax after WITHIN GROUP. @@ -4351,7 +4379,7 @@ class Parser(metaclass=_Parser): self._parse_set_operations(self._parse_select(nested=True, parse_subquery_alias=False)) ) - def _parse_transaction(self) -> exp.Transaction: + def _parse_transaction(self) -> exp.Transaction | exp.Command: this = None if self._match_texts(self.TRANSACTION_KIND): this = self._prev.text diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index ed14594..a19ebaa 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -3,6 +3,7 @@ from __future__ import annotations import typing as t from enum import auto +from sqlglot.errors import TokenError from sqlglot.helper import AutoName from sqlglot.trie import TrieResult, in_trie, new_trie @@ -800,7 +801,7 @@ class Tokenizer(metaclass=_Tokenizer): start = max(self._current - 50, 0) end = min(self._current + 50, self.size - 1) context = self.sql[start:end] - raise ValueError(f"Error tokenizing '{context}'") from e + raise TokenError(f"Error tokenizing '{context}'") from e return self.tokens @@ -1097,7 +1098,7 @@ class Tokenizer(metaclass=_Tokenizer): try: int(text, base) except: - raise RuntimeError( + raise TokenError( f"Numeric string contains invalid characters from {self._line}:{self._start}" ) else: @@ -1140,7 +1141,7 @@ class Tokenizer(metaclass=_Tokenizer): if self._current + 1 < self.size: self._advance(2) else: - raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._current}") + raise TokenError(f"Missing {delimiter} from {self._line}:{self._current}") else: if self._chars(delim_size) == delimiter: if delim_size > 1: @@ -1148,7 +1149,7 @@ class Tokenizer(metaclass=_Tokenizer): break if self._end: - raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._start}") + raise TokenError(f"Missing {delimiter} from {self._line}:{self._start}") current = self._current - 1 self._advance(alnum=True) |