diff options
Diffstat (limited to 'sqlglot')
37 files changed, 1388 insertions, 760 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 2207a28..e30232c 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -88,13 +88,11 @@ def parse( @t.overload -def parse_one(sql: str, *, into: t.Type[E], **opts) -> E: - ... +def parse_one(sql: str, *, into: t.Type[E], **opts) -> E: ... @t.overload -def parse_one(sql: str, **opts) -> Expression: - ... +def parse_one(sql: str, **opts) -> Expression: ... def parse_one( diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index 7e3f07b..0bacbf9 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -140,12 +140,10 @@ class DataFrame: return cte, name @t.overload - def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: - ... + def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: ... @t.overload - def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: - ... + def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: ... 
def _ensure_list_of_columns(self, cols): return Column.ensure_cols(ensure_list(cols)) diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index 308b639..db5201f 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -210,7 +210,7 @@ def sec(col: ColumnOrName) -> Column: def signum(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "SIGNUM") + return Column.invoke_expression_over_column(col, expression.Sign) def sin(col: ColumnOrName) -> Column: @@ -592,7 +592,7 @@ def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column: def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Column: - return Column.invoke_anonymous_function(start, "ADD_MONTHS", months) + return Column.invoke_expression_over_column(start, expression.AddMonths, expression=months) def months_between( diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index f867617..5bfc3ea 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -42,7 +42,10 @@ def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Va alias = expression.args.get("alias") for tup in expression.find_all(exp.Tuple): field_aliases = alias.columns if alias else (f"_c{i}" for i in range(len(tup.expressions))) - expressions = [exp.alias_(fld, name) for fld, name in zip(tup.expressions, field_aliases)] + expressions = [ + exp.PropertyEQ(this=exp.to_identifier(name), expression=fld) + for name, fld in zip(field_aliases, tup.expressions) + ] structs.append(exp.Struct(expressions=expressions)) return self.unnest_sql(exp.Unnest(expressions=[exp.array(*structs, copy=False)])) @@ -111,6 +114,8 @@ def _alias_ordered_group(expression: exp.Expression) -> exp.Expression: } for grouped in group.expressions: + if grouped.is_int: + continue alias = aliases.get(grouped) if alias: grouped.replace(exp.column(alias)) @@ -226,8 +231,11 @@ class BigQuery(Dialect): # 
bigquery udfs are case sensitive NORMALIZE_FUNCTIONS = False + # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_elements_date_time TIME_MAPPING = { "%D": "%m/%d/%y", + "%E*S": "%S.%f", + "%E6S": "%S.%f", } ESCAPE_SEQUENCES = { @@ -266,14 +274,20 @@ class BigQuery(Dialect): while isinstance(parent, exp.Dot): parent = parent.parent - # In BigQuery, CTEs aren't case-sensitive, but table names are (by default, at least). - # The following check is essentially a heuristic to detect tables based on whether or - # not they're qualified. It also avoids normalizing UDFs, because they're case-sensitive. - if ( - not isinstance(parent, exp.UserDefinedFunction) - and not (isinstance(parent, exp.Table) and parent.db) - and not expression.meta.get("is_table") - ): + # In BigQuery, CTEs are case-insensitive, but UDF and table names are case-sensitive + # by default. The following check uses a heuristic to detect tables based on whether + # they are qualified. This should generally be correct, because tables in BigQuery + # must be qualified with at least a dataset, unless @@dataset_id is set. 
+ case_sensitive = ( + isinstance(parent, exp.UserDefinedFunction) + or ( + isinstance(parent, exp.Table) + and parent.db + and (parent.meta.get("quoted_table") or not parent.meta.get("maybe_column")) + ) + or expression.meta.get("is_table") + ) + if not case_sensitive: expression.set("this", expression.this.lower()) return expression @@ -302,6 +316,7 @@ class BigQuery(Dialect): "BYTES": TokenType.BINARY, "CURRENT_DATETIME": TokenType.CURRENT_DATETIME, "DECLARE": TokenType.COMMAND, + "ELSEIF": TokenType.COMMAND, "EXCEPTION": TokenType.COMMAND, "FLOAT64": TokenType.DOUBLE, "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT, @@ -315,8 +330,8 @@ class BigQuery(Dialect): class Parser(parser.Parser): PREFIXED_PIVOT_COLUMNS = True - LOG_DEFAULTS_TO_LN = True + SUPPORTS_IMPLICIT_UNNEST = True FUNCTIONS = { **parser.Parser.FUNCTIONS, @@ -410,6 +425,7 @@ class BigQuery(Dialect): STATEMENT_PARSERS = { **parser.Parser.STATEMENT_PARSERS, + TokenType.ELSE: lambda self: self._parse_as_command(self._prev), TokenType.END: lambda self: self._parse_as_command(self._prev), TokenType.FOR: lambda self: self._parse_for_in(), } @@ -433,8 +449,11 @@ class BigQuery(Dialect): if isinstance(this, exp.Identifier): table_name = this.name while self._match(TokenType.DASH, advance=False) and self._next: - self._advance(2) - table_name += f"-{self._prev.text}" + text = "" + while self._curr and self._curr.token_type != TokenType.DOT: + self._advance() + text += self._prev.text + table_name += text this = exp.Identifier(this=table_name, quoted=this.args.get("quoted")) elif isinstance(this, exp.Literal): @@ -448,12 +467,28 @@ class BigQuery(Dialect): return this def _parse_table_parts( - self, schema: bool = False, is_db_reference: bool = False + self, schema: bool = False, is_db_reference: bool = False, wildcard: bool = False ) -> exp.Table: - table = super()._parse_table_parts(schema=schema, is_db_reference=is_db_reference) + table = super()._parse_table_parts( + schema=schema, 
is_db_reference=is_db_reference, wildcard=True + ) + + # proj-1.db.tbl -- `1.` is tokenized as a float so we need to unravel it here + if not table.catalog: + if table.db: + parts = table.db.split(".") + if len(parts) == 2 and not table.args["db"].quoted: + table.set("catalog", exp.Identifier(this=parts[0])) + table.set("db", exp.Identifier(this=parts[1])) + else: + parts = table.name.split(".") + if len(parts) == 2 and not table.this.quoted: + table.set("db", exp.Identifier(this=parts[0])) + table.set("this", exp.Identifier(this=parts[1])) + if isinstance(table.this, exp.Identifier) and "." in table.name: catalog, db, this, *rest = ( - t.cast(t.Optional[exp.Expression], exp.to_identifier(x)) + t.cast(t.Optional[exp.Expression], exp.to_identifier(x, quoted=True)) for x in split_num_words(table.name, ".", 3) ) @@ -461,16 +496,15 @@ class BigQuery(Dialect): this = exp.Dot.build(t.cast(t.List[exp.Expression], [this, *rest])) table = exp.Table(this=this, db=db, catalog=catalog) + table.meta["quoted_table"] = True return table @t.overload - def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: - ... + def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ... @t.overload - def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: - ... + def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ... def _parse_json_object(self, agg=False): json_object = super()._parse_json_object() @@ -532,6 +566,7 @@ class BigQuery(Dialect): IGNORE_NULLS_IN_FUNC = True JSON_PATH_SINGLE_QUOTE_ESCAPE = True CAN_IMPLEMENT_ARRAY_ANY = True + NAMED_PLACEHOLDER_TOKEN = "@" TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -762,22 +797,25 @@ class BigQuery(Dialect): "within", } + def table_parts(self, expression: exp.Table) -> str: + # Depending on the context, `x.y` may not resolve to the same data source as `x`.`y`, so + # we need to make sure the correct quoting is used in each case. 
+ # + # For example, if there is a CTE x that clashes with a schema name, then the former will + # return the table y in that schema, whereas the latter will return the CTE's y column: + # + # - WITH x AS (SELECT [1, 2] AS y) SELECT * FROM x, `x.y` -> cross join + # - WITH x AS (SELECT [1, 2] AS y) SELECT * FROM x, `x`.`y` -> implicit unnest + if expression.meta.get("quoted_table"): + table_parts = ".".join(p.name for p in expression.parts) + return self.sql(exp.Identifier(this=table_parts, quoted=True)) + + return super().table_parts(expression) + def timetostr_sql(self, expression: exp.TimeToStr) -> str: this = expression.this if isinstance(expression.this, exp.TsOrDsToDate) else expression return self.func("FORMAT_DATE", self.format_time(expression), this.this) - def struct_sql(self, expression: exp.Struct) -> str: - args = [] - for expr in expression.expressions: - if isinstance(expr, self.KEY_VALUE_DEFINITIONS): - arg = f"{self.sql(expr, 'expression')} AS {expr.this.name}" - else: - arg = self.sql(expr) - - args.append(arg) - - return self.func("STRUCT", *args) - def eq_sql(self, expression: exp.EQ) -> str: # Operands of = cannot be NULL in BigQuery if isinstance(expression.left, exp.Null) or isinstance(expression.right, exp.Null): @@ -803,7 +841,7 @@ class BigQuery(Dialect): def array_sql(self, expression: exp.Array) -> str: first_arg = seq_get(expression.expressions, 0) - if isinstance(first_arg, exp.Subqueryable): + if isinstance(first_arg, exp.Query): return f"ARRAY{self.wrap(self.sql(first_arg))}" return inline_array_sql(self, expression) diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 05d6a03..90167f6 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -68,7 +68,6 @@ class ClickHouse(Dialect): "DATE32": TokenType.DATE32, "DATETIME64": TokenType.DATETIME64, "DICTIONARY": TokenType.DICTIONARY, - "ENUM": TokenType.ENUM, "ENUM8": TokenType.ENUM8, "ENUM16": TokenType.ENUM16, "FINAL": 
TokenType.FINAL, @@ -93,6 +92,7 @@ class ClickHouse(Dialect): "AGGREGATEFUNCTION": TokenType.AGGREGATEFUNCTION, "SIMPLEAGGREGATEFUNCTION": TokenType.SIMPLEAGGREGATEFUNCTION, "SYSTEM": TokenType.COMMAND, + "PREWHERE": TokenType.PREWHERE, } SINGLE_TOKENS = { @@ -129,6 +129,7 @@ class ClickHouse(Dialect): "MAP": parser.build_var_map, "MATCH": exp.RegexpLike.from_arg_list, "RANDCANONICAL": exp.Rand.from_arg_list, + "TUPLE": exp.Struct.from_arg_list, "UNIQ": exp.ApproxDistinct.from_arg_list, "XOR": lambda args: exp.Xor(expressions=args), } @@ -390,7 +391,7 @@ class ClickHouse(Dialect): return self.expression( exp.CTE, - this=self._parse_field(), + this=self._parse_conjunction(), alias=self._parse_table_alias(), scalar=True, ) @@ -732,3 +733,7 @@ class ClickHouse(Dialect): return f"{this_name}{self.sep()}{this_properties}{self.sep()}{this_schema}" return super().createable_sql(expression, locations) + + def prewhere_sql(self, expression: exp.PreWhere) -> str: + this = self.indent(self.sql(expression, "this")) + return f"{self.seg('PREWHERE')}{self.sep()}{this}" diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 96eff18..188b6a7 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -69,7 +69,7 @@ class Databricks(Spark): def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str: constraint = expression.find(exp.GeneratedAsIdentityColumnConstraint) - kind = expression.args.get("kind") + kind = expression.kind if ( constraint and isinstance(kind, exp.DataType) diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index b0a78d2..599505c 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -443,7 +443,7 @@ class Dialect(metaclass=_Dialect): identify: If set to `False`, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy. 
""" - if isinstance(expression, exp.Identifier): + if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): name = expression.this expression.set( "quoted", diff --git a/sqlglot/dialects/doris.py b/sqlglot/dialects/doris.py index 067a045..9a84848 100644 --- a/sqlglot/dialects/doris.py +++ b/sqlglot/dialects/doris.py @@ -21,6 +21,7 @@ class Doris(MySQL): **MySQL.Parser.FUNCTIONS, "COLLECT_SET": exp.ArrayUniqueAgg.from_arg_list, "DATE_TRUNC": build_timestamp_trunc, + "MONTHS_ADD": exp.AddMonths.from_arg_list, "REGEXP": exp.RegexpLike.from_arg_list, "TO_DATE": exp.TsOrDsToDate.from_arg_list, } @@ -41,6 +42,7 @@ class Doris(MySQL): TRANSFORMS = { **MySQL.Generator.TRANSFORMS, + exp.AddMonths: rename_func("MONTHS_ADD"), exp.ApproxDistinct: approx_count_distinct_sql, exp.ArgMax: rename_func("MAX_BY"), exp.ArgMin: rename_func("MIN_BY"), @@ -58,7 +60,6 @@ class Doris(MySQL): exp.StrToUnix: lambda self, e: self.func("UNIX_TIMESTAMP", e.this, self.format_time(e)), exp.Split: rename_func("SPLIT_BY_STRING"), exp.TimeStrToDate: rename_func("TO_DATE"), - exp.ToChar: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)), exp.TsOrDsAdd: lambda self, e: self.func("DATE_ADD", e.this, e.expression), exp.TsOrDsToDate: lambda self, e: self.func("TO_DATE", e.this), exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index 4e699f5..c1f6afa 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -156,6 +156,3 @@ class Drill(Dialect): exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", } - - def normalize_func(self, name: str) -> str: - return name if exp.SAFE_IDENTIFIER_RE.match(name) else f"`{name}`" diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 925c5ae..f74dc97 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -79,6 +79,21 @@ def 
_build_date_diff(args: t.List) -> exp.Expression: return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)) +def _build_generate_series(end_exclusive: bool = False) -> t.Callable[[t.List], exp.GenerateSeries]: + def _builder(args: t.List) -> exp.GenerateSeries: + # Check https://duckdb.org/docs/sql/functions/nested.html#range-functions + if len(args) == 1: + # DuckDB uses 0 as a default for the series' start when it's omitted + args.insert(0, exp.Literal.number("0")) + + gen_series = exp.GenerateSeries.from_arg_list(args) + gen_series.set("is_end_exclusive", end_exclusive) + + return gen_series + + return _builder + + def _build_make_timestamp(args: t.List) -> exp.Expression: if len(args) == 1: return exp.UnixToTime(this=seq_get(args, 0), scale=exp.UnixToTime.MICROS) @@ -95,13 +110,13 @@ def _build_make_timestamp(args: t.List) -> exp.Expression: def _struct_sql(self: DuckDB.Generator, expression: exp.Struct) -> str: args: t.List[str] = [] - for expr in expression.expressions: - if isinstance(expr, exp.Alias): - key = expr.alias - value = expr.this - else: - key = expr.name or expr.this.name + for i, expr in enumerate(expression.expressions): + if isinstance(expr, exp.PropertyEQ): + key = expr.name value = expr.expression + else: + key = f"_{i}" + value = expr args.append(f"{self.sql(exp.Literal.string(key))}: {self.sql(value)}") @@ -148,13 +163,6 @@ def _rename_unless_within_group( ) -def _build_struct_pack(args: t.List) -> exp.Struct: - args_with_columns_as_identifiers = [ - exp.PropertyEQ(this=arg.this.this, expression=arg.expression) for arg in args - ] - return exp.Struct.from_arg_list(args_with_columns_as_identifiers) - - class DuckDB(Dialect): NULL_ORDERING = "nulls_are_last" SUPPORTS_USER_DEFINED_TYPES = False @@ -189,6 +197,7 @@ class DuckDB(Dialect): "CHARACTER VARYING": TokenType.TEXT, "EXCLUDE": TokenType.EXCEPT, "LOGICAL": TokenType.BOOLEAN, + "ONLY": TokenType.ONLY, "PIVOT_WIDER": TokenType.PIVOT, "SIGNED": 
TokenType.INT, "STRING": TokenType.VARCHAR, @@ -213,6 +222,8 @@ class DuckDB(Dialect): TokenType.TILDA: exp.RegexpLike, } + FUNCTIONS_WITH_ALIASED_ARGS = {*parser.Parser.FUNCTIONS_WITH_ALIASED_ARGS, "STRUCT_PACK"} + FUNCTIONS = { **parser.Parser.FUNCTIONS, "ARRAY_HAS": exp.ArrayContains.from_arg_list, @@ -261,12 +272,14 @@ class DuckDB(Dialect): "STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list, "STRING_TO_ARRAY": exp.Split.from_arg_list, "STRPTIME": build_formatted_time(exp.StrToTime, "duckdb"), - "STRUCT_PACK": _build_struct_pack, + "STRUCT_PACK": exp.Struct.from_arg_list, "STR_SPLIT": exp.Split.from_arg_list, "STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list, "TO_TIMESTAMP": exp.UnixToTime.from_arg_list, "UNNEST": exp.Explode.from_arg_list, "XOR": binary_from_function(exp.BitwiseXor), + "GENERATE_SERIES": _build_generate_series(), + "RANGE": _build_generate_series(end_exclusive=True), } FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy() @@ -313,6 +326,8 @@ class DuckDB(Dialect): return pivot_column_names(aggregations, dialect="duckdb") class Generator(generator.Generator): + PARAMETER_TOKEN = "$" + NAMED_PLACEHOLDER_TOKEN = "$" JOIN_HINTS = False TABLE_HINTS = False QUERY_HINTS = False @@ -535,5 +550,22 @@ class DuckDB(Dialect): return self.sql(expression, "this") return super().columndef_sql(expression, sep) - def placeholder_sql(self, expression: exp.Placeholder) -> str: - return f"${expression.name}" if expression.name else "?" 
+ def join_sql(self, expression: exp.Join) -> str: + if ( + expression.side == "LEFT" + and not expression.args.get("on") + and isinstance(expression.this, exp.Unnest) + ): + # Some dialects support `LEFT JOIN UNNEST(...)` without an explicit ON clause + # DuckDB doesn't, but we can just add a dummy ON clause that is always true + return super().join_sql(expression.on(exp.true())) + + return super().join_sql(expression) + + def generateseries_sql(self, expression: exp.GenerateSeries) -> str: + # GENERATE_SERIES(a, b) -> [a, b], RANGE(a, b) -> [a, b) + if expression.args.get("is_end_exclusive"): + expression.set("is_end_exclusive", None) + return rename_func("RANGE")(self, expression) + + return super().generateseries_sql(expression) diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 43211dc..55a9254 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -140,6 +140,15 @@ def _str_to_unix_sql(self: Hive.Generator, expression: exp.StrToUnix) -> str: return self.func("UNIX_TIMESTAMP", expression.this, time_format("hive")(self, expression)) +def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str: + timestamp = self.sql(expression, "this") + scale = expression.args.get("scale") + if scale in (None, exp.UnixToTime.SECONDS): + return rename_func("FROM_UNIXTIME")(self, expression) + + return f"FROM_UNIXTIME({timestamp} / POW(10, {scale}))" + + def _str_to_date_sql(self: Hive.Generator, expression: exp.StrToDate) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) @@ -536,7 +545,7 @@ class Hive(Dialect): exp.UnixToStr: lambda self, e: self.func( "FROM_UNIXTIME", e.this, time_format("hive")(self, e) ), - exp.UnixToTime: rename_func("FROM_UNIXTIME"), + exp.UnixToTime: _unix_to_time_sql, exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"), exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}", exp.SerdeProperties: lambda self, e: self.properties(e, 
prefix="WITH SERDEPROPERTIES"), @@ -609,9 +618,8 @@ class Hive(Dialect): return self.properties(properties, prefix=self.seg("TBLPROPERTIES")) def datatype_sql(self, expression: exp.DataType) -> str: - if ( - expression.this in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR) - and not expression.expressions + if expression.this in self.PARAMETERIZABLE_TEXT_TYPES and ( + not expression.expressions or expression.expressions[0].name == "MAX" ): expression = exp.DataType.build("text") elif expression.is_type(exp.DataType.Type.TEXT) and expression.expressions: @@ -631,3 +639,15 @@ class Hive(Dialect): def version_sql(self, expression: exp.Version) -> str: sql = super().version_sql(expression) return sql.replace("FOR ", "", 1) + + def struct_sql(self, expression: exp.Struct) -> str: + values = [] + + for i, e in enumerate(expression.expressions): + if isinstance(e, exp.PropertyEQ): + self.unsupported("Hive does not support named structs.") + values.append(e.expression) + else: + values.append(e) + + return self.func("STRUCT", *values) diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index e549f62..6ebae1e 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -185,7 +185,6 @@ class MySQL(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "CHARSET": TokenType.CHARACTER_SET, - "ENUM": TokenType.ENUM, "FORCE": TokenType.FORCE, "IGNORE": TokenType.IGNORE, "LOCK TABLES": TokenType.COMMAND, @@ -391,6 +390,11 @@ class MySQL(Dialect): "WARNINGS": _show_parser("WARNINGS"), } + PROPERTY_PARSERS = { + **parser.Parser.PROPERTY_PARSERS, + "LOCK": lambda self: self._parse_property_assignment(exp.LockProperty), + } + SET_PARSERS = { **parser.Parser.SET_PARSERS, "PERSIST": lambda self: self._parse_set_item_assignment("PERSIST"), @@ -416,16 +420,11 @@ class MySQL(Dialect): "SPATIAL", } - PROFILE_TYPES = { - "ALL", - "BLOCK IO", - "CONTEXT SWITCHES", - "CPU", - "IPC", - "MEMORY", - "PAGE FAULTS", - "SOURCE", - "SWAPS", + PROFILE_TYPES: 
parser.OPTIONS_TYPE = { + **dict.fromkeys(("ALL", "CPU", "IPC", "MEMORY", "SOURCE", "SWAPS"), tuple()), + "BLOCK": ("IO",), + "CONTEXT": ("SWITCHES",), + "PAGE": ("FAULTS",), } TYPE_TOKENS = { diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index fcb3aab..bccdad0 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -66,6 +66,26 @@ class Oracle(Dialect): "FF6": "%f", # only 6 digits are supported in python formats } + class Tokenizer(tokens.Tokenizer): + VAR_SINGLE_TOKENS = {"@", "$", "#"} + + KEYWORDS = { + **tokens.Tokenizer.KEYWORDS, + "(+)": TokenType.JOIN_MARKER, + "BINARY_DOUBLE": TokenType.DOUBLE, + "BINARY_FLOAT": TokenType.FLOAT, + "COLUMNS": TokenType.COLUMN, + "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE, + "MINUS": TokenType.EXCEPT, + "NVARCHAR2": TokenType.NVARCHAR, + "ORDER SIBLINGS BY": TokenType.ORDER_SIBLINGS_BY, + "SAMPLE": TokenType.TABLE_SAMPLE, + "START": TokenType.BEGIN, + "SYSDATE": TokenType.CURRENT_TIMESTAMP, + "TOP": TokenType.TOP, + "VARCHAR2": TokenType.VARCHAR, + } + class Parser(parser.Parser): ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER, TokenType.KEEP} @@ -93,6 +113,21 @@ class Oracle(Dialect): "XMLTABLE": lambda self: self._parse_xml_table(), } + NO_PAREN_FUNCTION_PARSERS = { + **parser.Parser.NO_PAREN_FUNCTION_PARSERS, + "CONNECT_BY_ROOT": lambda self: self.expression( + exp.ConnectByRoot, this=self._parse_column() + ), + } + + PROPERTY_PARSERS = { + **parser.Parser.PROPERTY_PARSERS, + "GLOBAL": lambda self: self._match_text_seq("TEMPORARY") + and self.expression(exp.TemporaryProperty, this="GLOBAL"), + "PRIVATE": lambda self: self._match_text_seq("TEMPORARY") + and self.expression(exp.TemporaryProperty, this="PRIVATE"), + } + QUERY_MODIFIER_PARSERS = { **parser.Parser.QUERY_MODIFIER_PARSERS, TokenType.ORDER_SIBLINGS_BY: lambda self: ("order", self._parse_order()), @@ -190,6 +225,7 @@ class Oracle(Dialect): TRANSFORMS = { 
**generator.Generator.TRANSFORMS, + exp.ConnectByRoot: lambda self, e: f"CONNECT_BY_ROOT {self.sql(e, 'this')}", exp.DateStrToDate: lambda self, e: self.func( "TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD") ), @@ -207,6 +243,7 @@ class Oracle(Dialect): exp.Substring: rename_func("SUBSTR"), exp.Table: lambda self, e: self.table_sql(e, sep=" "), exp.TableSample: lambda self, e: self.tablesample_sql(e, sep=" "), + exp.TemporaryProperty: lambda _, e: f"{e.name or 'GLOBAL'} TEMPORARY", exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)), exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.Trim: trim_sql, @@ -242,23 +279,3 @@ class Oracle(Dialect): if len(expression.args.get("actions", [])) > 1: return f"ADD ({actions})" return f"ADD {actions}" - - class Tokenizer(tokens.Tokenizer): - VAR_SINGLE_TOKENS = {"@", "$", "#"} - - KEYWORDS = { - **tokens.Tokenizer.KEYWORDS, - "(+)": TokenType.JOIN_MARKER, - "BINARY_DOUBLE": TokenType.DOUBLE, - "BINARY_FLOAT": TokenType.FLOAT, - "COLUMNS": TokenType.COLUMN, - "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE, - "MINUS": TokenType.EXCEPT, - "NVARCHAR2": TokenType.NVARCHAR, - "ORDER SIBLINGS BY": TokenType.ORDER_SIBLINGS_BY, - "SAMPLE": TokenType.TABLE_SAMPLE, - "START": TokenType.BEGIN, - "SYSDATE": TokenType.CURRENT_TIMESTAMP, - "TOP": TokenType.TOP, - "VARCHAR2": TokenType.VARCHAR, - } diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index c78f8a3..b53ae07 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -138,7 +138,9 @@ def _auto_increment_to_serial(expression: exp.Expression) -> exp.Expression: def _serial_to_generated(expression: exp.Expression) -> exp.Expression: - kind = expression.args.get("kind") + if not isinstance(expression, exp.ColumnDef): + return expression + kind = expression.kind if not kind: return expression @@ -279,6 +281,7 @@ class Postgres(Dialect): "TEMP": TokenType.TEMPORARY, "CSTRING": TokenType.PSEUDO_TYPE, 
"OID": TokenType.OBJECT_IDENTIFIER, + "ONLY": TokenType.ONLY, "OPERATOR": TokenType.OPERATOR, "REGCLASS": TokenType.OBJECT_IDENTIFIER, "REGCOLLATION": TokenType.OBJECT_IDENTIFIER, @@ -451,6 +454,7 @@ class Postgres(Dialect): exp.JSONBExtract: lambda self, e: self.binary(e, "#>"), exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"), exp.JSONBContains: lambda self, e: self.binary(e, "?"), + exp.ParseJSON: lambda self, e: self.sql(exp.cast(e.this, exp.DataType.Type.JSON)), exp.JSONPathKey: json_path_key_only_name, exp.JSONPathRoot: lambda *_: "", exp.JSONPathSubscript: lambda self, e: self.json_path_part(e.this), @@ -506,6 +510,26 @@ class Postgres(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def unnest_sql(self, expression: exp.Unnest) -> str: + if len(expression.expressions) == 1: + from sqlglot.optimizer.annotate_types import annotate_types + + this = annotate_types(expression.expressions[0]) + if this.is_type("array<json>"): + while isinstance(this, exp.Cast): + this = this.this + + arg = self.sql(exp.cast(this, exp.DataType.Type.JSON)) + alias = self.sql(expression, "alias") + alias = f" AS {alias}" if alias else "" + + if expression.args.get("offset"): + self.unsupported("Unsupported JSON_ARRAY_ELEMENTS with offset") + + return f"JSON_ARRAY_ELEMENTS({arg}){alias}" + + return super().unnest_sql(expression) + def bracket_sql(self, expression: exp.Bracket) -> str: """Forms like ARRAY[1, 2, 3][3] aren't allowed; we need to wrap the ARRAY.""" if isinstance(expression.this, exp.Array): diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 8429547..3649bd2 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -453,11 +453,32 @@ class Presto(Dialect): return super().bracket_sql(expression) def struct_sql(self, expression: exp.Struct) -> str: - if any(isinstance(arg, self.KEY_VALUE_DEFINITIONS) for arg in expression.expressions): - self.unsupported("Struct with key-value definitions is 
unsupported.") - return self.function_fallback_sql(expression) + from sqlglot.optimizer.annotate_types import annotate_types + + expression = annotate_types(expression) + values: t.List[str] = [] + schema: t.List[str] = [] + unknown_type = False + + for e in expression.expressions: + if isinstance(e, exp.PropertyEQ): + if e.type and e.type.is_type(exp.DataType.Type.UNKNOWN): + unknown_type = True + else: + schema.append(f"{self.sql(e, 'this')} {self.sql(e.type)}") + values.append(self.sql(e, "expression")) + else: + values.append(self.sql(e)) + + size = len(expression.expressions) - return rename_func("ROW")(self, expression) + if not size or len(schema) != size: + if unknown_type: + self.unsupported( + "Cannot convert untyped key-value definitions (try annotate_types)." + ) + return self.func("ROW", *values) + return f"CAST(ROW({', '.join(values)}) AS ROW({', '.join(schema)}))" def interval_sql(self, expression: exp.Interval) -> str: unit = self.sql(expression, "unit") diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 2201c78..0db87ec 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -70,6 +70,8 @@ class Redshift(Postgres): "SYSDATE": lambda self: self.expression(exp.CurrentTimestamp, transaction=True), } + SUPPORTS_IMPLICIT_UNNEST = True + def _parse_table( self, schema: bool = False, @@ -124,27 +126,6 @@ class Redshift(Postgres): self._retreat(index) return None - def _parse_query_modifiers( - self, this: t.Optional[exp.Expression] - ) -> t.Optional[exp.Expression]: - this = super()._parse_query_modifiers(this) - - if this: - refs = set() - - for i, join in enumerate(this.args.get("joins", [])): - refs.add( - ( - this.args["from"] if i == 0 else this.args["joins"][i - 1] - ).this.alias.lower() - ) - - table = join.this - if isinstance(table, exp.Table) and not join.args.get("on"): - if table.parts[0].name.lower() in refs: - table.replace(table.to_column()) - return this - class 
Tokenizer(Postgres.Tokenizer): BIT_STRINGS = [] HEX_STRINGS = [] @@ -225,6 +206,18 @@ class Redshift(Postgres): RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot", "type"} + def unnest_sql(self, expression: exp.Unnest) -> str: + args = expression.expressions + num_args = len(args) + + if num_args > 1: + self.unsupported(f"Unsupported number of arguments in UNNEST: {num_args}") + return "" + + arg = self.sql(seq_get(args, 0)) + alias = self.expressions(expression.args.get("alias"), key="columns") + return f"{arg} AS {alias}" if alias else arg + def with_properties(self, properties: exp.Properties) -> str: """Redshift doesn't have `WITH` as part of their with_properties so we remove it""" return self.properties(properties, prefix=" ", suffix="") diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index c773e50..20fdfb7 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -21,7 +21,7 @@ from sqlglot.dialects.dialect import ( var_map_sql, ) from sqlglot.expressions import Literal -from sqlglot.helper import is_int, seq_get +from sqlglot.helper import flatten, is_int, seq_get from sqlglot.tokens import TokenType if t.TYPE_CHECKING: @@ -66,7 +66,7 @@ def _build_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]: return exp.Struct( expressions=[ - t.cast(exp.Condition, k).eq(v) for k, v in zip(expression.keys, expression.values) + exp.PropertyEQ(this=k, expression=v) for k, v in zip(expression.keys, expression.values) ] ) @@ -409,8 +409,16 @@ class Snowflake(Dialect): "TERSE OBJECTS": _show_parser("OBJECTS"), "TABLES": _show_parser("TABLES"), "TERSE TABLES": _show_parser("TABLES"), + "VIEWS": _show_parser("VIEWS"), + "TERSE VIEWS": _show_parser("VIEWS"), "PRIMARY KEYS": _show_parser("PRIMARY KEYS"), "TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"), + "IMPORTED KEYS": _show_parser("IMPORTED KEYS"), + "TERSE IMPORTED KEYS": _show_parser("IMPORTED KEYS"), + "UNIQUE KEYS": 
_show_parser("UNIQUE KEYS"), + "TERSE UNIQUE KEYS": _show_parser("UNIQUE KEYS"), + "SEQUENCES": _show_parser("SEQUENCES"), + "TERSE SEQUENCES": _show_parser("SEQUENCES"), "COLUMNS": _show_parser("COLUMNS"), "USERS": _show_parser("USERS"), "TERSE USERS": _show_parser("USERS"), @@ -424,11 +432,13 @@ class Snowflake(Dialect): FLATTEN_COLUMNS = ["SEQ", "KEY", "PATH", "INDEX", "VALUE", "THIS"] + SCHEMA_KINDS = {"OBJECTS", "TABLES", "VIEWS", "SEQUENCES", "UNIQUE KEYS", "IMPORTED KEYS"} + def _parse_colon_get_path( self: parser.Parser, this: t.Optional[exp.Expression] ) -> t.Optional[exp.Expression]: while True: - path = self._parse_bitwise() + path = self._parse_bitwise() or self._parse_var(any_token=True) # The cast :: operator has a lower precedence than the extraction operator :, so # we rearrange the AST appropriately to avoid casting the 2nd argument of GET_PATH @@ -535,7 +545,7 @@ class Snowflake(Dialect): return table def _parse_table_parts( - self, schema: bool = False, is_db_reference: bool = False + self, schema: bool = False, is_db_reference: bool = False, wildcard: bool = False ) -> exp.Table: # https://docs.snowflake.com/en/user-guide/querying-stage if self._match(TokenType.STRING, advance=False): @@ -603,7 +613,7 @@ class Snowflake(Dialect): if self._curr: scope = self._parse_table_parts() elif self._curr: - scope_kind = "SCHEMA" if this in ("OBJECTS", "TABLES") else "TABLE" + scope_kind = "SCHEMA" if this in self.SCHEMA_KINDS else "TABLE" scope = self._parse_table_parts() return self.expression( @@ -758,10 +768,6 @@ class Snowflake(Dialect): "POSITION", e.args.get("substr"), e.this, e.args.get("position") ), exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)), - exp.Struct: lambda self, e: self.func( - "OBJECT_CONSTRUCT", - *(arg for expression in e.expressions for arg in expression.flatten()), - ), exp.Stuff: rename_func("INSERT"), exp.TimestampDiff: lambda self, e: self.func( "TIMESTAMPDIFF", e.unit, e.expression, e.this 
@@ -937,3 +943,19 @@ class Snowflake(Dialect): def cluster_sql(self, expression: exp.Cluster) -> str: return f"CLUSTER BY ({self.expressions(expression, flat=True)})" + + def struct_sql(self, expression: exp.Struct) -> str: + keys = [] + values = [] + + for i, e in enumerate(expression.expressions): + if isinstance(e, exp.PropertyEQ): + keys.append( + exp.Literal.string(e.name) if isinstance(e.this, exp.Identifier) else e.this + ) + values.append(e.expression) + else: + keys.append(exp.Literal.string(f"_{i}")) + values.append(e) + + return self.func("OBJECT_CONSTRUCT", *flatten(zip(keys, values))) diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index 60cf8e1..63eae6e 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -263,14 +263,9 @@ class Spark2(Hive): CREATE_FUNCTION_RETURN_AS = False def struct_sql(self, expression: exp.Struct) -> str: - args = [] - for arg in expression.expressions: - if isinstance(arg, self.KEY_VALUE_DEFINITIONS): - args.append(exp.alias_(arg.expression, arg.this.name)) - else: - args.append(arg) - - return self.func("STRUCT", *args) + from sqlglot.generator import Generator + + return Generator.struct_sql(self, expression) def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: if is_parse_json(expression.this): diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 6596c5b..2b17ff9 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -92,6 +92,7 @@ class SQLite(Dialect): NVL2_SUPPORTED = False JSON_PATH_BRACKETED_KEY_SUPPORTED = False SUPPORTS_CREATE_TABLE_LIKE = False + SUPPORTS_TABLE_ALIAS_COLUMNS = False SUPPORTED_JSON_PATH_PARTS = { exp.JSONPathKey, @@ -173,6 +174,21 @@ class SQLite(Dialect): return super().cast_sql(expression) + def generateseries_sql(self, expression: exp.GenerateSeries) -> str: + parent = expression.parent + alias = parent and parent.args.get("alias") + + if isinstance(alias, exp.TableAlias) and 
alias.columns: + column_alias = alias.columns[0] + alias.set("columns", None) + sql = self.sql( + exp.select(exp.alias_("value", column_alias)).from_(expression).subquery() + ) + else: + sql = super().generateseries_sql(expression) + + return sql + def datediff_sql(self, expression: exp.DateDiff) -> str: unit = expression.args.get("unit") unit = unit.name.upper() if unit else "DAY" diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 5955352..b6f491f 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -18,7 +18,6 @@ from sqlglot.dialects.dialect import ( timestrtotime_sql, trim_sql, ) -from sqlglot.expressions import DataType from sqlglot.helper import seq_get from sqlglot.time import format_time from sqlglot.tokens import TokenType @@ -63,6 +62,44 @@ DEFAULT_START_DATE = datetime.date(1900, 1, 1) BIT_TYPES = {exp.EQ, exp.NEQ, exp.Is, exp.In, exp.Select, exp.Alias} +# Unsupported options: +# - OPTIMIZE FOR ( @variable_name { UNKNOWN | = <literal_constant> } [ , ...n ] ) +# - TABLE HINT +OPTIONS: parser.OPTIONS_TYPE = { + **dict.fromkeys( + ( + "DISABLE_OPTIMIZED_PLAN_FORCING", + "FAST", + "IGNORE_NONCLUSTERED_COLUMNSTORE_INDEX", + "LABEL", + "MAXDOP", + "MAXRECURSION", + "MAX_GRANT_PERCENT", + "MIN_GRANT_PERCENT", + "NO_PERFORMANCE_SPOOL", + "QUERYTRACEON", + "RECOMPILE", + ), + tuple(), + ), + "CONCAT": ("UNION",), + "DISABLE": ("EXTERNALPUSHDOWN", "SCALEOUTEXECUTION"), + "EXPAND": ("VIEWS",), + "FORCE": ("EXTERNALPUSHDOWN", "ORDER", "SCALEOUTEXECUTION"), + "HASH": ("GROUP", "JOIN", "UNION"), + "KEEP": ("PLAN",), + "KEEPFIXED": ("PLAN",), + "LOOP": ("JOIN",), + "MERGE": ("JOIN", "UNION"), + "OPTIMIZE": (("FOR", "UNKNOWN"),), + "ORDER": ("GROUP",), + "PARAMETERIZATION": ("FORCED", "SIMPLE"), + "ROBUST": ("PLAN",), + "USE": ("PLAN",), +} + +OPTIONS_THAT_REQUIRE_EQUAL = ("MAX_GRANT_PERCENT", "MIN_GRANT_PERCENT", "LABEL") + def _build_formatted_time( exp_class: t.Type[E], full_format_mapping: t.Optional[bool] = None @@ -221,19 
+258,17 @@ def qualify_derived_table_outputs(expression: exp.Expression) -> exp.Expression: # We keep track of the unaliased column projection indexes instead of the expressions # themselves, because the latter are going to be replaced by new nodes when the aliases # are added and hence we won't be able to reach these newly added Alias parents - subqueryable = expression.this + query = expression.this unaliased_column_indexes = ( - i - for i, c in enumerate(subqueryable.selects) - if isinstance(c, exp.Column) and not c.alias + i for i, c in enumerate(query.selects) if isinstance(c, exp.Column) and not c.alias ) - qualify_outputs(subqueryable) + qualify_outputs(query) # Preserve the quoting information of columns for newly added Alias nodes - subqueryable_selects = subqueryable.selects + query_selects = query.selects for select_index in unaliased_column_indexes: - alias = subqueryable_selects[select_index] + alias = query_selects[select_index] column = alias.this if isinstance(column.this, exp.Identifier): alias.args["alias"].set("quoted", column.this.quoted) @@ -420,7 +455,6 @@ class TSQL(Dialect): "IMAGE": TokenType.IMAGE, "MONEY": TokenType.MONEY, "NTEXT": TokenType.TEXT, - "NVARCHAR(MAX)": TokenType.TEXT, "PRINT": TokenType.COMMAND, "PROC": TokenType.PROCEDURE, "REAL": TokenType.FLOAT, @@ -431,15 +465,24 @@ class TSQL(Dialect): "TOP": TokenType.TOP, "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER, "UPDATE STATISTICS": TokenType.COMMAND, - "VARCHAR(MAX)": TokenType.TEXT, "XML": TokenType.XML, "OUTPUT": TokenType.RETURNING, "SYSTEM_USER": TokenType.CURRENT_USER, "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT, + "OPTION": TokenType.OPTION, } class Parser(parser.Parser): SET_REQUIRES_ASSIGNMENT_DELIMITER = False + LOG_DEFAULTS_TO_LN = True + ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False + STRING_ALIASES = True + NO_PAREN_IF_COMMANDS = False + + QUERY_MODIFIER_PARSERS = { + **parser.Parser.QUERY_MODIFIER_PARSERS, + TokenType.OPTION: lambda self: ("options", 
self._parse_options()), + } FUNCTIONS = { **parser.Parser.FUNCTIONS, @@ -472,19 +515,7 @@ class TSQL(Dialect): "TIMEFROMPARTS": _build_timefromparts, } - JOIN_HINTS = { - "LOOP", - "HASH", - "MERGE", - "REMOTE", - } - - VAR_LENGTH_DATATYPES = { - DataType.Type.NVARCHAR, - DataType.Type.VARCHAR, - DataType.Type.CHAR, - DataType.Type.NCHAR, - } + JOIN_HINTS = {"LOOP", "HASH", "MERGE", "REMOTE"} RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - { TokenType.TABLE, @@ -496,11 +527,21 @@ class TSQL(Dialect): TokenType.END: lambda self: self._parse_command(), } - LOG_DEFAULTS_TO_LN = True + def _parse_options(self) -> t.Optional[t.List[exp.Expression]]: + if not self._match(TokenType.OPTION): + return None - ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False - STRING_ALIASES = True - NO_PAREN_IF_COMMANDS = False + def _parse_option() -> t.Optional[exp.Expression]: + option = self._parse_var_from_options(OPTIONS) + if not option: + return None + + self._match(TokenType.EQ) + return self.expression( + exp.QueryOption, this=option, expression=self._parse_primary_or_var() + ) + + return self._parse_wrapped_csv(_parse_option) def _parse_projections(self) -> t.List[exp.Expression]: """ @@ -576,48 +617,13 @@ class TSQL(Dialect): def _parse_convert( self, strict: bool, safe: t.Optional[bool] = None ) -> t.Optional[exp.Expression]: - to = self._parse_types() + this = self._parse_types() self._match(TokenType.COMMA) - this = self._parse_conjunction() - - if not to or not this: - return None - - # Retrieve length of datatype and override to default if not specified - if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES: - to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False) - - # Check whether a conversion with format is applicable - if self._match(TokenType.COMMA): - format_val = self._parse_number() - format_val_name = format_val.name if format_val else "" - - if format_val_name not in TSQL.CONVERT_FORMAT_MAPPING: - raise 
ValueError( - f"CONVERT function at T-SQL does not support format style {format_val_name}" - ) - - format_norm = exp.Literal.string(TSQL.CONVERT_FORMAT_MAPPING[format_val_name]) - - # Check whether the convert entails a string to date format - if to.this == DataType.Type.DATE: - return self.expression(exp.StrToDate, this=this, format=format_norm) - # Check whether the convert entails a string to datetime format - elif to.this == DataType.Type.DATETIME: - return self.expression(exp.StrToTime, this=this, format=format_norm) - # Check whether the convert entails a date to string format - elif to.this in self.VAR_LENGTH_DATATYPES: - return self.expression( - exp.Cast if strict else exp.TryCast, - to=to, - this=self.expression(exp.TimeToStr, this=this, format=format_norm), - safe=safe, - ) - elif to.this == DataType.Type.TEXT: - return self.expression(exp.TimeToStr, this=this, format=format_norm) - - # Entails a simple cast without any format requirement - return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to, safe=safe) + args = [this, *self._parse_csv(self._parse_conjunction)] + convert = exp.Convert.from_arg_list(args) + convert.set("safe", safe) + convert.set("strict", strict) + return convert def _parse_user_defined_function( self, kind: t.Optional[TokenType] = None @@ -683,6 +689,26 @@ class TSQL(Dialect): return self.expression(exp.UniqueColumnConstraint, this=this) + def _parse_partition(self) -> t.Optional[exp.Partition]: + if not self._match_text_seq("WITH", "(", "PARTITIONS"): + return None + + def parse_range(): + low = self._parse_bitwise() + high = self._parse_bitwise() if self._match_text_seq("TO") else None + + return ( + self.expression(exp.PartitionRange, this=low, expression=high) if high else low + ) + + partition = self.expression( + exp.Partition, expressions=self._parse_wrapped_csv(parse_range) + ) + + self._match_r_paren() + + return partition + class Generator(generator.Generator): LIMIT_IS_TOP = True QUERY_HINTS = False 
@@ -728,6 +754,9 @@ class TSQL(Dialect): exp.DataType.Type.VARIANT: "SQL_VARIANT", } + TYPE_MAPPING.pop(exp.DataType.Type.NCHAR) + TYPE_MAPPING.pop(exp.DataType.Type.NVARCHAR) + TRANSFORMS = { **generator.Generator.TRANSFORMS, exp.AnyValue: any_value_to_max_sql, @@ -779,6 +808,20 @@ class TSQL(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def convert_sql(self, expression: exp.Convert) -> str: + name = "TRY_CONVERT" if expression.args.get("safe") else "CONVERT" + return self.func( + name, expression.this, expression.expression, expression.args.get("style") + ) + + def queryoption_sql(self, expression: exp.QueryOption) -> str: + option = self.sql(expression, "this") + value = self.sql(expression, "expression") + if value: + optional_equal_sign = "= " if option in OPTIONS_THAT_REQUIRE_EQUAL else "" + return f"{option} {optional_equal_sign}{value}" + return option + def lateral_op(self, expression: exp.Lateral) -> str: cross_apply = expression.args.get("cross_apply") if cross_apply is True: @@ -876,11 +919,10 @@ class TSQL(Dialect): if ctas_with: ctas_with = ctas_with.pop() - subquery = ctas_expression - if isinstance(subquery, exp.Subqueryable): - subquery = subquery.subquery() + if isinstance(ctas_expression, exp.UNWRAPPED_QUERIES): + ctas_expression = ctas_expression.subquery() - select_into = exp.select("*").from_(exp.alias_(subquery, "temp", table=True)) + select_into = exp.select("*").from_(exp.alias_(ctas_expression, "temp", table=True)) select_into.set("into", exp.Into(this=table)) select_into.set("with", ctas_with) @@ -993,3 +1035,6 @@ class TSQL(Dialect): this_sql = self.sql(this) expression_sql = self.sql(expression, "expression") return self.func(name, this_sql, expression_sql if expression_sql else None) + + def partition_sql(self, expression: exp.Partition) -> str: + return f"WITH (PARTITIONS({self.expressions(expression, flat=True)}))" diff --git a/sqlglot/diff.py b/sqlglot/diff.py index c10d640..bda9136 100644 --- 
a/sqlglot/diff.py +++ b/sqlglot/diff.py @@ -119,13 +119,18 @@ def diff( return ChangeDistiller(**kwargs).diff(source_copy, target_copy, matchings=matchings_copy) -LEAF_EXPRESSION_TYPES = ( +# The expression types for which Update edits are allowed. +UPDATABLE_EXPRESSION_TYPES = ( exp.Boolean, exp.DataType, - exp.Identifier, exp.Literal, + exp.Table, + exp.Column, + exp.Lambda, ) +IGNORED_LEAF_EXPRESSION_TYPES = (exp.Identifier,) + class ChangeDistiller: """ @@ -152,8 +157,16 @@ class ChangeDistiller: self._source = source self._target = target - self._source_index = {id(n): n for n, *_ in self._source.bfs()} - self._target_index = {id(n): n for n, *_ in self._target.bfs()} + self._source_index = { + id(n): n + for n, *_ in self._source.bfs() + if not isinstance(n, IGNORED_LEAF_EXPRESSION_TYPES) + } + self._target_index = { + id(n): n + for n, *_ in self._target.bfs() + if not isinstance(n, IGNORED_LEAF_EXPRESSION_TYPES) + } self._unmatched_source_nodes = set(self._source_index) - set(pre_matched_nodes) self._unmatched_target_nodes = set(self._target_index) - set(pre_matched_nodes.values()) self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {} @@ -170,7 +183,10 @@ class ChangeDistiller: for kept_source_node_id, kept_target_node_id in matching_set: source_node = self._source_index[kept_source_node_id] target_node = self._target_index[kept_target_node_id] - if not isinstance(source_node, LEAF_EXPRESSION_TYPES) or source_node == target_node: + if ( + not isinstance(source_node, UPDATABLE_EXPRESSION_TYPES) + or source_node == target_node + ): edit_script.extend( self._generate_move_edits(source_node, target_node, matching_set) ) @@ -307,17 +323,16 @@ def _get_leaves(expression: exp.Expression) -> t.Iterator[exp.Expression]: has_child_exprs = False for _, node in expression.iter_expressions(): - has_child_exprs = True - yield from _get_leaves(node) + if not isinstance(node, IGNORED_LEAF_EXPRESSION_TYPES): + has_child_exprs = True + yield from 
_get_leaves(node) if not has_child_exprs: yield expression def _is_same_type(source: exp.Expression, target: exp.Expression) -> bool: - if type(source) is type(target) and ( - not isinstance(source, exp.Identifier) or type(source.parent) is type(target.parent) - ): + if type(source) is type(target): if isinstance(source, exp.Join): return source.args.get("side") == target.args.get("side") @@ -343,7 +358,11 @@ def _expression_only_args(expression: exp.Expression) -> t.List[exp.Expression]: if expression: for a in expression.args.values(): args.extend(ensure_list(a)) - return [a for a in args if isinstance(a, exp.Expression)] + return [ + a + for a in args + if isinstance(a, exp.Expression) and not isinstance(a, IGNORED_LEAF_EXPRESSION_TYPES) + ] def _lcs( diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py index e4c4040..a411c18 100644 --- a/sqlglot/executor/context.py +++ b/sqlglot/executor/context.py @@ -78,7 +78,7 @@ class Context: def sort(self, key) -> None: def sort_key(row: t.Tuple) -> t.Tuple: self.set_row(row) - return self.eval_tuple(key) + return tuple((t is None, t) for t in self.eval_tuple(key)) self.table.rows.sort(key=sort_key) diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index c0becbe..a2b23d4 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -142,7 +142,6 @@ class PythonExecutor: context = self.context({alias: table}) yield context types = [] - for row in reader: if not types: for v in row: @@ -150,7 +149,11 @@ class PythonExecutor: types.append(type(ast.literal_eval(v))) except (ValueError, SyntaxError): types.append(str) - context.set_row(tuple(t(v) for t, v in zip(types, row))) + + # We can't cast empty values ('') to non-string types, so we convert them to None instead + context.set_row( + tuple(None if (t is not str and v == "") else t(v) for t, v in zip(types, row)) + ) yield context.table.reader def join(self, step, context): diff --git a/sqlglot/expressions.py 
b/sqlglot/expressions.py index 1408d3c..1a24875 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -548,12 +548,10 @@ class Expression(metaclass=_Expression): return new_node @t.overload - def replace(self, expression: E) -> E: - ... + def replace(self, expression: E) -> E: ... @t.overload - def replace(self, expression: None) -> None: - ... + def replace(self, expression: None) -> None: ... def replace(self, expression): """ @@ -913,14 +911,142 @@ class Predicate(Condition): class DerivedTable(Expression): @property def selects(self) -> t.List[Expression]: - return self.this.selects if isinstance(self.this, Subqueryable) else [] + return self.this.selects if isinstance(self.this, Query) else [] @property def named_selects(self) -> t.List[str]: return [select.output_name for select in self.selects] -class Unionable(Expression): +class Query(Expression): + def subquery(self, alias: t.Optional[ExpOrStr] = None, copy: bool = True) -> Subquery: + """ + Returns a `Subquery` that wraps around this query. + + Example: + >>> subquery = Select().select("x").from_("tbl").subquery() + >>> Select().select("x").from_(subquery).sql() + 'SELECT x FROM (SELECT x FROM tbl)' + + Args: + alias: an optional alias for the subquery. + copy: if `False`, modify this expression instance in-place. + """ + instance = maybe_copy(self, copy) + if not isinstance(alias, Expression): + alias = TableAlias(this=to_identifier(alias)) if alias else None + + return Subquery(this=instance, alias=alias) + + def limit( + self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts + ) -> Select: + """ + Adds a LIMIT clause to this query. + + Example: + >>> select("1").union(select("1")).limit(1).sql() + 'SELECT * FROM (SELECT 1 UNION SELECT 1) AS _l_0 LIMIT 1' + + Args: + expression: the SQL code string to parse. + This can also be an integer. + If a `Limit` instance is passed, it will be used as-is. 
+ If another `Expression` instance is passed, it will be wrapped in a `Limit`. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + A limited Select expression. + """ + return ( + select("*") + .from_(self.subquery(alias="_l_0", copy=copy)) + .limit(expression, dialect=dialect, copy=False, **opts) + ) + + @property + def ctes(self) -> t.List[CTE]: + """Returns a list of all the CTEs attached to this query.""" + with_ = self.args.get("with") + return with_.expressions if with_ else [] + + @property + def selects(self) -> t.List[Expression]: + """Returns the query's projections.""" + raise NotImplementedError("Query objects must implement `selects`") + + @property + def named_selects(self) -> t.List[str]: + """Returns the output names of the query's projections.""" + raise NotImplementedError("Query objects must implement `named_selects`") + + def select( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Query: + """ + Append to or set the SELECT expressions. + + Example: + >>> Select().select("x", "y").sql() + 'SELECT x, y' + + Args: + *expressions: the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + append: if `True`, add to any existing expressions. + Otherwise, this resets the expressions. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + The modified Query expression. 
+ """ + raise NotImplementedError("Query objects must implement `select`") + + def with_( + self, + alias: ExpOrStr, + as_: ExpOrStr, + recursive: t.Optional[bool] = None, + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Query: + """ + Append to or set the common table expressions. + + Example: + >>> Select().with_("tbl2", as_="SELECT * FROM tbl").select("x").from_("tbl2").sql() + 'WITH tbl2 AS (SELECT * FROM tbl) SELECT x FROM tbl2' + + Args: + alias: the SQL code string to parse as the table name. + If an `Expression` instance is passed, this is used as-is. + as_: the SQL code string to parse as the table expression. + If an `Expression` instance is passed, it will be used as-is. + recursive: set the RECURSIVE part of the expression. Defaults to `False`. + append: if `True`, add to any existing expressions. + Otherwise, this resets the expressions. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + The modified expression. + """ + return _apply_cte_builder( + self, alias, as_, recursive=recursive, append=append, dialect=dialect, copy=copy, **opts + ) + def union( self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts ) -> Union: @@ -946,7 +1072,7 @@ class Unionable(Expression): def intersect( self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts - ) -> Unionable: + ) -> Intersect: """ Builds an INTERSECT expression. @@ -969,7 +1095,7 @@ class Unionable(Expression): def except_( self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts - ) -> Unionable: + ) -> Except: """ Builds an EXCEPT expression. 
@@ -991,7 +1117,7 @@ class Unionable(Expression): return except_(left=self, right=expression, distinct=distinct, dialect=dialect, **opts) -class UDTF(DerivedTable, Unionable): +class UDTF(DerivedTable): @property def selects(self) -> t.List[Expression]: alias = self.args.get("alias") @@ -1017,23 +1143,23 @@ class Refresh(Expression): class DDL(Expression): @property - def ctes(self): + def ctes(self) -> t.List[CTE]: + """Returns a list of all the CTEs attached to this statement.""" with_ = self.args.get("with") - if not with_: - return [] - return with_.expressions + return with_.expressions if with_ else [] @property - def named_selects(self) -> t.List[str]: - if isinstance(self.expression, Subqueryable): - return self.expression.named_selects - return [] + def selects(self) -> t.List[Expression]: + """If this statement contains a query (e.g. a CTAS), this returns the query's projections.""" + return self.expression.selects if isinstance(self.expression, Query) else [] @property - def selects(self) -> t.List[Expression]: - if isinstance(self.expression, Subqueryable): - return self.expression.selects - return [] + def named_selects(self) -> t.List[str]: + """ + If this statement contains a query (e.g. a CTAS), this returns the output + names of the query's projections. 
+ """ + return self.expression.named_selects if isinstance(self.expression, Query) else [] class DML(Expression): @@ -1096,6 +1222,19 @@ class Create(DDL): return kind and kind.upper() +class TruncateTable(Expression): + arg_types = { + "expressions": True, + "is_database": False, + "exists": False, + "only": False, + "cluster": False, + "identity": False, + "option": False, + "partition": False, + } + + # https://docs.snowflake.com/en/sql-reference/sql/create-clone # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_clone_statement # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_copy @@ -1271,6 +1410,10 @@ class ColumnDef(Expression): def constraints(self) -> t.List[ColumnConstraint]: return self.args.get("constraints") or [] + @property + def kind(self) -> t.Optional[DataType]: + return self.args.get("kind") + class AlterColumn(Expression): arg_types = { @@ -1367,7 +1510,7 @@ class CharacterSetColumnConstraint(ColumnConstraintKind): class CheckColumnConstraint(ColumnConstraintKind): - pass + arg_types = {"this": True, "enforced": False} class ClusteredColumnConstraint(ColumnConstraintKind): @@ -1776,6 +1919,10 @@ class Partition(Expression): arg_types = {"expressions": True} +class PartitionRange(Expression): + arg_types = {"this": True, "expression": True} + + class Fetch(Expression): arg_types = { "direction": False, @@ -2173,6 +2320,10 @@ class LocationProperty(Property): arg_types = {"this": True} +class LockProperty(Property): + arg_types = {"this": True} + + class LockingProperty(Property): arg_types = { "this": False, @@ -2310,7 +2461,7 @@ class StabilityProperty(Property): class TemporaryProperty(Property): - arg_types = {} + arg_types = {"this": False} class TransformModelProperty(Property): @@ -2356,6 +2507,7 @@ class Properties(Expression): "FORMAT": FileFormatProperty, "LANGUAGE": LanguageProperty, "LOCATION": LocationProperty, + "LOCK": 
LockProperty, "PARTITIONED_BY": PartitionedByProperty, "RETURNS": ReturnsProperty, "ROW_FORMAT": RowFormatProperty, @@ -2445,102 +2597,13 @@ class Tuple(Expression): ) -class Subqueryable(Unionable): - def subquery(self, alias: t.Optional[ExpOrStr] = None, copy: bool = True) -> Subquery: - """ - Convert this expression to an aliased expression that can be used as a Subquery. - - Example: - >>> subquery = Select().select("x").from_("tbl").subquery() - >>> Select().select("x").from_(subquery).sql() - 'SELECT x FROM (SELECT x FROM tbl)' - - Args: - alias (str | Identifier): an optional alias for the subquery - copy (bool): if `False`, modify this expression instance in-place. - - Returns: - Alias: the subquery - """ - instance = maybe_copy(self, copy) - if not isinstance(alias, Expression): - alias = TableAlias(this=to_identifier(alias)) if alias else None - - return Subquery(this=instance, alias=alias) - - def limit( - self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts - ) -> Select: - raise NotImplementedError - - @property - def ctes(self): - with_ = self.args.get("with") - if not with_: - return [] - return with_.expressions - - @property - def selects(self) -> t.List[Expression]: - raise NotImplementedError("Subqueryable objects must implement `selects`") - - @property - def named_selects(self) -> t.List[str]: - raise NotImplementedError("Subqueryable objects must implement `named_selects`") - - def select( - self, - *expressions: t.Optional[ExpOrStr], - append: bool = True, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Subqueryable: - raise NotImplementedError("Subqueryable objects must implement `select`") - - def with_( - self, - alias: ExpOrStr, - as_: ExpOrStr, - recursive: t.Optional[bool] = None, - append: bool = True, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Subqueryable: - """ - Append to or set the common table expressions. 
- - Example: - >>> Select().with_("tbl2", as_="SELECT * FROM tbl").select("x").from_("tbl2").sql() - 'WITH tbl2 AS (SELECT * FROM tbl) SELECT x FROM tbl2' - - Args: - alias: the SQL code string to parse as the table name. - If an `Expression` instance is passed, this is used as-is. - as_: the SQL code string to parse as the table expression. - If an `Expression` instance is passed, it will be used as-is. - recursive: set the RECURSIVE part of the expression. Defaults to `False`. - append: if `True`, add to any existing expressions. - Otherwise, this resets the expressions. - dialect: the dialect used to parse the input expression. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - The modified expression. - """ - return _apply_cte_builder( - self, alias, as_, recursive=recursive, append=append, dialect=dialect, copy=copy, **opts - ) - - QUERY_MODIFIERS = { "match": False, "laterals": False, "joins": False, "connect": False, "pivots": False, + "prewhere": False, "where": False, "group": False, "having": False, @@ -2556,9 +2619,16 @@ QUERY_MODIFIERS = { "sample": False, "settings": False, "format": False, + "options": False, } +# https://learn.microsoft.com/en-us/sql/t-sql/queries/option-clause-transact-sql?view=sql-server-ver16 +# https://learn.microsoft.com/en-us/sql/t-sql/queries/hints-transact-sql-query?view=sql-server-ver16 +class QueryOption(Expression): + arg_types = {"this": True, "expression": False} + + # https://learn.microsoft.com/en-us/sql/t-sql/queries/hints-transact-sql-table?view=sql-server-ver16 class WithTableHint(Expression): arg_types = {"expressions": True} @@ -2590,6 +2660,7 @@ class Table(Expression): "pattern": False, "ordinality": False, "when": False, + "only": False, } @property @@ -2638,7 +2709,7 @@ class Table(Expression): return col -class Union(Subqueryable): +class Union(Query): arg_types = { "with": False, "this": True, @@ -2648,34 +2719,6 @@ class 
Union(Subqueryable): **QUERY_MODIFIERS, } - def limit( - self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts - ) -> Select: - """ - Set the LIMIT expression. - - Example: - >>> select("1").union(select("1")).limit(1).sql() - 'SELECT * FROM (SELECT 1 UNION SELECT 1) AS _l_0 LIMIT 1' - - Args: - expression: the SQL code string to parse. - This can also be an integer. - If a `Limit` instance is passed, this is used as-is. - If another `Expression` instance is passed, it will be wrapped in a `Limit`. - dialect: the dialect used to parse the input expression. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - The limited subqueryable. - """ - return ( - select("*") - .from_(self.subquery(alias="_l_0", copy=copy)) - .limit(expression, dialect=dialect, copy=False, **opts) - ) - def select( self, *expressions: t.Optional[ExpOrStr], @@ -2684,26 +2727,7 @@ class Union(Subqueryable): copy: bool = True, **opts, ) -> Union: - """Append to or set the SELECT of the union recursively. - - Example: - >>> from sqlglot import parse_one - >>> parse_one("select a from x union select a from y union select a from z").select("b").sql() - 'SELECT a, b FROM x UNION SELECT a, b FROM y UNION SELECT a, b FROM z' - - Args: - *expressions: the SQL code strings to parse. - If an `Expression` instance is passed, it will be used as-is. - append: if `True`, add to any existing expressions. - Otherwise, this resets the expressions. - dialect: the dialect used to parse the input expressions. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - Union: the modified expression. 
- """ - this = self.copy() if copy else self + this = maybe_copy(self, copy) this.this.unnest().select(*expressions, append=append, dialect=dialect, copy=False, **opts) this.expression.unnest().select( *expressions, append=append, dialect=dialect, copy=False, **opts @@ -2800,7 +2824,7 @@ class Lock(Expression): arg_types = {"update": True, "expressions": False, "wait": False} -class Select(Subqueryable): +class Select(Query): arg_types = { "with": False, "kind": False, @@ -3011,25 +3035,6 @@ class Select(Subqueryable): def limit( self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts ) -> Select: - """ - Set the LIMIT expression. - - Example: - >>> Select().from_("tbl").select("x").limit(10).sql() - 'SELECT x FROM tbl LIMIT 10' - - Args: - expression: the SQL code string to parse. - This can also be an integer. - If a `Limit` instance is passed, this is used as-is. - If another `Expression` instance is passed, it will be wrapped in a `Limit`. - dialect: the dialect used to parse the input expression. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - Select: the modified expression. - """ return _apply_builder( expression=expression, instance=self, @@ -3084,31 +3089,13 @@ class Select(Subqueryable): copy: bool = True, **opts, ) -> Select: - """ - Append to or set the SELECT expressions. - - Example: - >>> Select().select("x", "y").sql() - 'SELECT x, y' - - Args: - *expressions: the SQL code strings to parse. - If an `Expression` instance is passed, it will be used as-is. - append: if `True`, add to any existing expressions. - Otherwise, this resets the expressions. - dialect: the dialect used to parse the input expressions. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - The modified Select expression. 
- """ return _apply_list_builder( *expressions, instance=self, arg="expressions", append=append, dialect=dialect, + into=Expression, copy=copy, **opts, ) @@ -3416,12 +3403,8 @@ class Select(Subqueryable): The new Create expression. """ instance = maybe_copy(self, copy) - table_expression = maybe_parse( - table, - into=Table, - dialect=dialect, - **opts, - ) + table_expression = maybe_parse(table, into=Table, dialect=dialect, **opts) + properties_expression = None if properties: properties_expression = Properties.from_dict(properties) @@ -3493,7 +3476,10 @@ class Select(Subqueryable): return self.expressions -class Subquery(DerivedTable, Unionable): +UNWRAPPED_QUERIES = (Select, Union) + + +class Subquery(DerivedTable, Query): arg_types = { "this": True, "alias": False, @@ -3502,9 +3488,7 @@ class Subquery(DerivedTable, Unionable): } def unnest(self): - """ - Returns the first non subquery. - """ + """Returns the first non subquery.""" expression = self while isinstance(expression, Subquery): expression = expression.this @@ -3516,6 +3500,18 @@ class Subquery(DerivedTable, Unionable): expression = t.cast(Subquery, expression.parent) return expression + def select( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Subquery: + this = maybe_copy(self, copy) + this.unnest().select(*expressions, append=append, dialect=dialect, copy=False, **opts) + return this + @property def is_wrapper(self) -> bool: """ @@ -3603,6 +3599,10 @@ class WindowSpec(Expression): } +class PreWhere(Expression): + pass + + class Where(Expression): pass @@ -3646,6 +3646,10 @@ class Boolean(Condition): class DataTypeParam(Expression): arg_types = {"this": True, "expression": False} + @property + def name(self) -> str: + return self.this.name + class DataType(Expression): arg_types = { @@ -3926,11 +3930,17 @@ class Rollback(Expression): class AlterTable(Expression): - arg_types = {"this": True, "actions": True, 
"exists": False, "only": False} + arg_types = { + "this": True, + "actions": True, + "exists": False, + "only": False, + "options": False, + } class AddConstraint(Expression): - arg_types = {"this": False, "expression": False, "enforced": False} + arg_types = {"expressions": True} class DropPartition(Expression): @@ -3996,6 +4006,10 @@ class Overlaps(Binary): class Dot(Binary): @property + def is_star(self) -> bool: + return self.expression.is_star + + @property def name(self) -> str: return self.expression.name @@ -4390,6 +4404,10 @@ class Anonymous(Func): arg_types = {"this": True, "expressions": False} is_var_len_args = True + @property + def name(self) -> str: + return self.this if isinstance(self.this, str) else self.this.name + class AnonymousAggFunc(AggFunc): arg_types = {"this": True, "expressions": False} @@ -4433,8 +4451,13 @@ class ToChar(Func): arg_types = {"this": True, "format": False, "nlsparam": False} +# https://learn.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-ver16#syntax +class Convert(Func): + arg_types = {"this": True, "expression": True, "style": False} + + class GenerateSeries(Func): - arg_types = {"start": True, "end": True, "step": False} + arg_types = {"start": True, "end": True, "step": False, "is_end_exclusive": False} class ArrayAgg(AggFunc): @@ -4624,6 +4647,11 @@ class ConcatWs(Concat): _sql_names = ["CONCAT_WS"] +# https://docs.oracle.com/cd/B13789_01/server.101/b10759/operators004.htm#i1035022 +class ConnectByRoot(Func): + pass + + class Count(AggFunc): arg_types = {"this": False, "expressions": False} is_var_len_args = True @@ -5197,6 +5225,10 @@ class Month(Func): pass +class AddMonths(Func): + arg_types = {"this": True, "expression": True} + + class Nvl2(Func): arg_types = {"this": True, "true": True, "false": False} @@ -5313,6 +5345,10 @@ class SHA2(Func): arg_types = {"this": True, "length": False} +class Sign(Func): + _sql_names = ["SIGN", "SIGNUM"] + + class SortArray(Func): 
arg_types = {"this": True, "asc": False} @@ -5554,7 +5590,13 @@ class Use(Expression): class Merge(Expression): - arg_types = {"this": True, "using": True, "on": True, "expressions": True, "with": False} + arg_types = { + "this": True, + "using": True, + "on": True, + "expressions": True, + "with": False, + } class When(Func): @@ -5587,8 +5629,7 @@ def maybe_parse( prefix: t.Optional[str] = None, copy: bool = False, **opts, -) -> E: - ... +) -> E: ... @t.overload @@ -5600,8 +5641,7 @@ def maybe_parse( prefix: t.Optional[str] = None, copy: bool = False, **opts, -) -> E: - ... +) -> E: ... def maybe_parse( @@ -5653,13 +5693,11 @@ def maybe_parse( @t.overload -def maybe_copy(instance: None, copy: bool = True) -> None: - ... +def maybe_copy(instance: None, copy: bool = True) -> None: ... @t.overload -def maybe_copy(instance: E, copy: bool = True) -> E: - ... +def maybe_copy(instance: E, copy: bool = True) -> E: ... def maybe_copy(instance, copy=True): @@ -6282,15 +6320,13 @@ SAFE_IDENTIFIER_RE: t.Pattern[str] = re.compile(r"^[_a-zA-Z][\w]*$") @t.overload -def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None: - ... +def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None: ... @t.overload def to_identifier( name: str | Identifier, quoted: t.Optional[bool] = None, copy: bool = True -) -> Identifier: - ... +) -> Identifier: ... def to_identifier(name, quoted=None, copy=True): @@ -6362,13 +6398,11 @@ def to_interval(interval: str | Literal) -> Interval: @t.overload -def to_table(sql_path: str | Table, **kwargs) -> Table: - ... +def to_table(sql_path: str | Table, **kwargs) -> Table: ... @t.overload -def to_table(sql_path: None, **kwargs) -> None: - ... +def to_table(sql_path: None, **kwargs) -> None: ... 
def to_table( @@ -6929,7 +6963,7 @@ def replace_placeholders(expression: Expression, *args, **kwargs) -> Expression: if isinstance(node, Placeholder): if node.name: new_name = kwargs.get(node.name) - if new_name: + if new_name is not None: return convert(new_name) else: try: @@ -6943,7 +6977,7 @@ def replace_placeholders(expression: Expression, *args, **kwargs) -> Expression: def expand( expression: Expression, - sources: t.Dict[str, Subqueryable], + sources: t.Dict[str, Query], dialect: DialectType = None, copy: bool = True, ) -> Expression: @@ -6959,7 +6993,7 @@ def expand( Args: expression: The expression to expand. - sources: A dictionary of name to Subqueryables. + sources: A dictionary of name to Queries. dialect: The dialect of the sources dict. copy: Whether to copy the expression during transformation. Defaults to True. diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 4bb5005..e6f5c4b 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -73,17 +73,16 @@ class Generator(metaclass=_Generator): TRANSFORMS: t.Dict[t.Type[exp.Expression], t.Callable[..., str]] = { **JSON_PATH_PART_TRANSFORMS, exp.AutoRefreshProperty: lambda self, e: f"AUTO REFRESH {self.sql(e, 'this')}", - exp.CaseSpecificColumnConstraint: lambda self, + exp.CaseSpecificColumnConstraint: lambda _, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC", exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}", exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}", - exp.CheckColumnConstraint: lambda self, e: f"CHECK ({self.sql(e, 'this')})", exp.ClusteredColumnConstraint: lambda self, e: f"CLUSTERED ({self.expressions(e, 'this', indent=False)})", exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}", exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}", - exp.CopyGrantsProperty: lambda self, e: "COPY GRANTS", + 
exp.CopyGrantsProperty: lambda *_: "COPY GRANTS", exp.DateAdd: lambda self, e: self.func( "DATE_ADD", e.this, e.expression, exp.Literal.string(e.text("unit")) ), @@ -91,8 +90,8 @@ class Generator(metaclass=_Generator): exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}", exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}", exp.ExecuteAsProperty: lambda self, e: self.naked_property(e), - exp.ExternalProperty: lambda self, e: "EXTERNAL", - exp.HeapProperty: lambda self, e: "HEAP", + exp.ExternalProperty: lambda *_: "EXTERNAL", + exp.HeapProperty: lambda *_: "HEAP", exp.InheritsProperty: lambda self, e: f"INHERITS ({self.expressions(e, flat=True)})", exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}", exp.InputModelProperty: lambda self, e: f"INPUT{self.sql(e, 'this')}", @@ -105,13 +104,13 @@ class Generator(metaclass=_Generator): ), exp.LanguageProperty: lambda self, e: self.naked_property(e), exp.LocationProperty: lambda self, e: self.naked_property(e), - exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG", - exp.MaterializedProperty: lambda self, e: "MATERIALIZED", + exp.LogProperty: lambda _, e: f"{'NO ' if e.args.get('no') else ''}LOG", + exp.MaterializedProperty: lambda *_: "MATERIALIZED", exp.NonClusteredColumnConstraint: lambda self, e: f"NONCLUSTERED ({self.expressions(e, 'this', indent=False)})", - exp.NoPrimaryIndexProperty: lambda self, e: "NO PRIMARY INDEX", - exp.NotForReplicationColumnConstraint: lambda self, e: "NOT FOR REPLICATION", - exp.OnCommitProperty: lambda self, + exp.NoPrimaryIndexProperty: lambda *_: "NO PRIMARY INDEX", + exp.NotForReplicationColumnConstraint: lambda *_: "NOT FOR REPLICATION", + exp.OnCommitProperty: lambda _, e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS", exp.OnProperty: lambda self, e: f"ON {self.sql(e, 'this')}", exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 
'this')}", @@ -122,21 +121,21 @@ class Generator(metaclass=_Generator): exp.ReturnsProperty: lambda self, e: self.naked_property(e), exp.SampleProperty: lambda self, e: f"SAMPLE BY {self.sql(e, 'this')}", exp.SetConfigProperty: lambda self, e: self.sql(e, "this"), - exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET", + exp.SetProperty: lambda _, e: f"{'MULTI' if e.args.get('multi') else ''}SET", exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}", - exp.SqlReadWriteProperty: lambda self, e: e.name, - exp.SqlSecurityProperty: lambda self, + exp.SqlReadWriteProperty: lambda _, e: e.name, + exp.SqlSecurityProperty: lambda _, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}", - exp.StabilityProperty: lambda self, e: e.name, - exp.TemporaryProperty: lambda self, e: "TEMPORARY", + exp.StabilityProperty: lambda _, e: e.name, + exp.TemporaryProperty: lambda *_: "TEMPORARY", exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}", exp.Timestamp: lambda self, e: self.func("TIMESTAMP", e.this, e.expression), exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}", exp.TransformModelProperty: lambda self, e: self.func("TRANSFORM", *e.expressions), - exp.TransientProperty: lambda self, e: "TRANSIENT", - exp.UppercaseColumnConstraint: lambda self, e: "UPPERCASE", + exp.TransientProperty: lambda *_: "TRANSIENT", + exp.UppercaseColumnConstraint: lambda *_: "UPPERCASE", exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]), - exp.VolatileProperty: lambda self, e: "VOLATILE", + exp.VolatileProperty: lambda *_: "VOLATILE", exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}", } @@ -356,6 +355,7 @@ class Generator(metaclass=_Generator): STRUCT_DELIMITER = ("<", ">") PARAMETER_TOKEN = "@" + NAMED_PLACEHOLDER_TOKEN = ":" PROPERTIES_LOCATION = { exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE, @@ -388,6 
+388,7 @@ class Generator(metaclass=_Generator): exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA, exp.LikeProperty: exp.Properties.Location.POST_SCHEMA, exp.LocationProperty: exp.Properties.Location.POST_SCHEMA, + exp.LockProperty: exp.Properties.Location.POST_SCHEMA, exp.LockingProperty: exp.Properties.Location.POST_ALIAS, exp.LogProperty: exp.Properties.Location.POST_NAME, exp.MaterializedProperty: exp.Properties.Location.POST_CREATE, @@ -459,11 +460,16 @@ class Generator(metaclass=_Generator): exp.Paren, ) + PARAMETERIZABLE_TEXT_TYPES = { + exp.DataType.Type.NVARCHAR, + exp.DataType.Type.VARCHAR, + exp.DataType.Type.CHAR, + exp.DataType.Type.NCHAR, + } + # Expressions that need to have all CTEs under them bubbled up to them EXPRESSIONS_WITHOUT_NESTED_CTES: t.Set[t.Type[exp.Expression]] = set() - KEY_VALUE_DEFINITIONS = (exp.EQ, exp.PropertyEQ, exp.Slice) - SENTINEL_LINE_BREAK = "__SQLGLOT__LB__" __slots__ = ( @@ -630,7 +636,7 @@ class Generator(metaclass=_Generator): this_sql = self.indent( ( self.sql(expression) - if isinstance(expression, (exp.Select, exp.Union)) + if isinstance(expression, exp.UNWRAPPED_QUERIES) else self.sql(expression, "this") ), level=1, @@ -1535,8 +1541,8 @@ class Generator(metaclass=_Generator): expr = self.sql(expression, "expression") return f"{this} ({kind} => {expr})" - def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str: - table = ".".join( + def table_parts(self, expression: exp.Table) -> str: + return ".".join( self.sql(part) for part in ( expression.args.get("catalog"), @@ -1546,6 +1552,9 @@ class Generator(metaclass=_Generator): if part is not None ) + def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str: + table = self.table_parts(expression) + only = "ONLY " if expression.args.get("only") else "" version = self.sql(expression, "version") version = f" {version}" if version else "" alias = self.sql(expression, "alias") @@ -1572,7 +1581,7 @@ class Generator(metaclass=_Generator): if 
when: table = f"{table} {when}" - return f"{table}{version}{file_format}{alias}{hints}{pivots}{joins}{laterals}{ordinality}" + return f"{only}{table}{version}{file_format}{alias}{hints}{pivots}{joins}{laterals}{ordinality}" def tablesample_sql( self, @@ -1681,7 +1690,7 @@ class Generator(metaclass=_Generator): alias_node = expression.args.get("alias") column_names = alias_node and alias_node.columns - selects: t.List[exp.Subqueryable] = [] + selects: t.List[exp.Query] = [] for i, tup in enumerate(expression.expressions): row = tup.expressions @@ -1697,10 +1706,8 @@ class Generator(metaclass=_Generator): # This may result in poor performance for large-cardinality `VALUES` tables, due to # the deep nesting of the resulting exp.Unions. If this is a problem, either increase # `sys.setrecursionlimit` to avoid RecursionErrors, or don't set `pretty`. - subqueryable = reduce(lambda x, y: exp.union(x, y, distinct=False, copy=False), selects) - return self.subquery_sql( - subqueryable.subquery(alias_node and alias_node.this, copy=False) - ) + query = reduce(lambda x, y: exp.union(x, y, distinct=False, copy=False), selects) + return self.subquery_sql(query.subquery(alias_node and alias_node.this, copy=False)) alias = f" AS {self.sql(alias_node, 'this')}" if alias_node else "" unions = " UNION ALL ".join(self.sql(select) for select in selects) @@ -1854,7 +1861,7 @@ class Generator(metaclass=_Generator): ] args_sql = ", ".join(self.sql(e) for e in args) - args_sql = f"({args_sql})" if any(top and not e.is_number for e in args) else args_sql + args_sql = f"({args_sql})" if top and any(not e.is_number for e in args) else args_sql expressions = self.expressions(expression, flat=True) expressions = f" BY {expressions}" if expressions else "" @@ -2070,12 +2077,17 @@ class Generator(metaclass=_Generator): else [] ) + options = self.expressions(expression, key="options") + if options: + options = f" OPTION{self.wrap(options)}" + return csv( *sqls, *[self.sql(join) for join in 
expression.args.get("joins") or []], self.sql(expression, "connect"), self.sql(expression, "match"), *[self.sql(lateral) for lateral in expression.args.get("laterals") or []], + self.sql(expression, "prewhere"), self.sql(expression, "where"), self.sql(expression, "group"), self.sql(expression, "having"), @@ -2083,9 +2095,13 @@ class Generator(metaclass=_Generator): self.sql(expression, "order"), *offset_limit_modifiers, *self.after_limit_modifiers(expression), + options, sep="", ) + def queryoption_sql(self, expression: exp.QueryOption) -> str: + return "" + def offset_limit_modifiers( self, expression: exp.Expression, fetch: bool, limit: t.Optional[exp.Fetch | exp.Limit] ) -> t.List[str]: @@ -2140,9 +2156,9 @@ class Generator(metaclass=_Generator): self.sql( exp.Struct( expressions=[ - exp.column(e.output_name).eq( - e.this if isinstance(e, exp.Alias) else e - ) + exp.PropertyEQ(this=e.args.get("alias"), expression=e.this) + if isinstance(e, exp.Alias) + else e for e in expression.expressions ] ) @@ -2204,7 +2220,7 @@ class Generator(metaclass=_Generator): return f"@@{kind}{this}" def placeholder_sql(self, expression: exp.Placeholder) -> str: - return f":{expression.name}" if expression.name else "?" + return f"{self.NAMED_PLACEHOLDER_TOKEN}{expression.name}" if expression.name else "?" 
def subquery_sql(self, expression: exp.Subquery, sep: str = " AS ") -> str: alias = self.sql(expression, "alias") @@ -2261,6 +2277,9 @@ class Generator(metaclass=_Generator): return f"UNNEST({args}){suffix}" + def prewhere_sql(self, expression: exp.PreWhere) -> str: + return "" + def where_sql(self, expression: exp.Where) -> str: this = self.indent(self.sql(expression, "this")) return f"{self.seg('WHERE')}{self.sep()}{this}" @@ -2326,7 +2345,7 @@ class Generator(metaclass=_Generator): def any_sql(self, expression: exp.Any) -> str: this = self.sql(expression, "this") - if isinstance(expression.this, exp.Subqueryable): + if isinstance(expression.this, exp.UNWRAPPED_QUERIES): this = self.wrap(this) return f"ANY {this}" @@ -2568,7 +2587,7 @@ class Generator(metaclass=_Generator): is_global = " GLOBAL" if expression.args.get("is_global") else "" if query: - in_sql = self.wrap(query) + in_sql = self.wrap(self.sql(query)) elif unnest: in_sql = self.in_unnest_op(unnest) elif field: @@ -2610,7 +2629,7 @@ class Generator(metaclass=_Generator): return f"REFERENCES {this}{expressions}{options}" def anonymous_sql(self, expression: exp.Anonymous) -> str: - return self.func(expression.name, *expression.expressions) + return self.func(self.sql(expression, "this"), *expression.expressions) def paren_sql(self, expression: exp.Paren) -> str: if isinstance(expression.unnest(), exp.Select): @@ -2822,7 +2841,9 @@ class Generator(metaclass=_Generator): exists = " IF EXISTS" if expression.args.get("exists") else "" only = " ONLY" if expression.args.get("only") else "" - return f"ALTER TABLE{exists}{only} {self.sql(expression, 'this')} {actions}" + options = self.expressions(expression, key="options") + options = f", {options}" if options else "" + return f"ALTER TABLE{exists}{only} {self.sql(expression, 'this')} {actions}{options}" def add_column_sql(self, expression: exp.AlterTable) -> str: if self.ALTER_TABLE_INCLUDE_COLUMN_KEYWORD: @@ -2839,15 +2860,7 @@ class 
Generator(metaclass=_Generator): return f"DROP{exists}{expressions}" def addconstraint_sql(self, expression: exp.AddConstraint) -> str: - this = self.sql(expression, "this") - expression_ = self.sql(expression, "expression") - add_constraint = f"ADD CONSTRAINT {this}" if this else "ADD" - - enforced = expression.args.get("enforced") - if enforced is not None: - return f"{add_constraint} CHECK ({expression_}){' ENFORCED' if enforced else ''}" - - return f"{add_constraint} {expression_}" + return f"ADD {self.expressions(expression)}" def distinct_sql(self, expression: exp.Distinct) -> str: this = self.expressions(expression, flat=True) @@ -3296,6 +3309,10 @@ class Generator(metaclass=_Generator): self.unsupported("Unsupported index constraint option.") return "" + def checkcolumnconstraint_sql(self, expression: exp.CheckColumnConstraint) -> str: + enforced = " ENFORCED" if expression.args.get("enforced") else "" + return f"CHECK ({self.sql(expression, 'this')}){enforced}" + def indexcolumnconstraint_sql(self, expression: exp.IndexColumnConstraint) -> str: kind = self.sql(expression, "kind") kind = f"{kind} INDEX" if kind else "INDEX" @@ -3452,9 +3469,87 @@ class Generator(metaclass=_Generator): return expression - def _ensure_string_if_null(self, values: t.List[exp.Expression]) -> t.List[exp.Expression]: - return [ - exp.func("COALESCE", exp.cast(value, "text"), exp.Literal.string("")) - for value in values - if value - ] + def generateseries_sql(self, expression: exp.GenerateSeries) -> str: + expression.set("is_end_exclusive", None) + return self.function_fallback_sql(expression) + + def struct_sql(self, expression: exp.Struct) -> str: + expression.set( + "expressions", + [ + exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e + for e in expression.expressions + ], + ) + + return self.function_fallback_sql(expression) + + def partitionrange_sql(self, expression: exp.PartitionRange) -> str: + low = self.sql(expression, "this") + high = 
self.sql(expression, "expression") + + return f"{low} TO {high}" + + def truncatetable_sql(self, expression: exp.TruncateTable) -> str: + target = "DATABASE" if expression.args.get("is_database") else "TABLE" + tables = f" {self.expressions(expression)}" + + exists = " IF EXISTS" if expression.args.get("exists") else "" + + on_cluster = self.sql(expression, "cluster") + on_cluster = f" {on_cluster}" if on_cluster else "" + + identity = self.sql(expression, "identity") + identity = f" {identity} IDENTITY" if identity else "" + + option = self.sql(expression, "option") + option = f" {option}" if option else "" + + partition = self.sql(expression, "partition") + partition = f" {partition}" if partition else "" + + return f"TRUNCATE {target}{exists}{tables}{on_cluster}{identity}{option}{partition}" + + # This transpiles T-SQL's CONVERT function + # https://learn.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-ver16 + def convert_sql(self, expression: exp.Convert) -> str: + to = expression.this + value = expression.expression + style = expression.args.get("style") + safe = expression.args.get("safe") + strict = expression.args.get("strict") + + if not to or not value: + return "" + + # Retrieve length of datatype and override to default if not specified + if not seq_get(to.expressions, 0) and to.this in self.PARAMETERIZABLE_TEXT_TYPES: + to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False) + + transformed: t.Optional[exp.Expression] = None + cast = exp.Cast if strict else exp.TryCast + + # Check whether a conversion with format (T-SQL calls this 'style') is applicable + if isinstance(style, exp.Literal) and style.is_int: + from sqlglot.dialects.tsql import TSQL + + style_value = style.name + converted_style = TSQL.CONVERT_FORMAT_MAPPING.get(style_value) + if not converted_style: + self.unsupported(f"Unsupported T-SQL 'style' value: {style_value}") + + fmt = exp.Literal.string(converted_style) + + if 
to.this == exp.DataType.Type.DATE: + transformed = exp.StrToDate(this=value, format=fmt) + elif to.this == exp.DataType.Type.DATETIME: + transformed = exp.StrToTime(this=value, format=fmt) + elif to.this in self.PARAMETERIZABLE_TEXT_TYPES: + transformed = cast(this=exp.TimeToStr(this=value, format=fmt), to=to, safe=safe) + elif to.this == exp.DataType.Type.TEXT: + transformed = exp.TimeToStr(this=value, format=fmt) + + if not transformed: + transformed = cast(this=value, to=to, safe=safe) + + return self.sql(transformed) diff --git a/sqlglot/helper.py b/sqlglot/helper.py index 35a4586..0d4547f 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -53,13 +53,11 @@ def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]: @t.overload -def ensure_list(value: t.Collection[T]) -> t.List[T]: - ... +def ensure_list(value: t.Collection[T]) -> t.List[T]: ... @t.overload -def ensure_list(value: T) -> t.List[T]: - ... +def ensure_list(value: T) -> t.List[T]: ... def ensure_list(value): @@ -81,13 +79,11 @@ def ensure_list(value): @t.overload -def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: - ... +def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: ... @t.overload -def ensure_collection(value: T) -> t.Collection[T]: - ... +def ensure_collection(value: T) -> t.Collection[T]: ... 
def ensure_collection(value): diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py index f10fbb9..eb428dc 100644 --- a/sqlglot/lineage.py +++ b/sqlglot/lineage.py @@ -1,16 +1,19 @@ from __future__ import annotations import json +import logging import typing as t from dataclasses import dataclass, field from sqlglot import Schema, exp, maybe_parse from sqlglot.errors import SqlglotError -from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, qualify +from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, normalize_identifiers, qualify if t.TYPE_CHECKING: from sqlglot.dialects.dialect import DialectType +logger = logging.getLogger("sqlglot") + @dataclass(frozen=True) class Node: @@ -18,7 +21,8 @@ class Node: expression: exp.Expression source: exp.Expression downstream: t.List[Node] = field(default_factory=list) - alias: str = "" + source_name: str = "" + reference_node_name: str = "" def walk(self) -> t.Iterator[Node]: yield self @@ -67,7 +71,7 @@ def lineage( column: str | exp.Column, sql: str | exp.Expression, schema: t.Optional[t.Dict | Schema] = None, - sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None, + sources: t.Optional[t.Dict[str, str | exp.Query]] = None, dialect: DialectType = None, **kwargs, ) -> Node: @@ -86,14 +90,12 @@ def lineage( """ expression = maybe_parse(sql, dialect=dialect) + column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name if sources: expression = exp.expand( expression, - { - k: t.cast(exp.Subqueryable, maybe_parse(v, dialect=dialect)) - for k, v in sources.items() - }, + {k: t.cast(exp.Query, maybe_parse(v, dialect=dialect)) for k, v in sources.items()}, dialect=dialect, ) @@ -109,122 +111,141 @@ def lineage( if not scope: raise SqlglotError("Cannot build lineage, sql must be SELECT") - def to_node( - column: str | int, - scope: Scope, - scope_name: t.Optional[str] = None, - upstream: t.Optional[Node] = None, - alias: t.Optional[str] = None, - ) -> Node: - aliases = 
{ - dt.alias: dt.comments[0].split()[1] - for dt in scope.derived_tables - if dt.comments and dt.comments[0].startswith("source: ") - } + if not any(select.alias_or_name == column for select in scope.expression.selects): + raise SqlglotError(f"Cannot find column '{column}' in query.") - # Find the specific select clause that is the source of the column we want. - # This can either be a specific, named select or a generic `*` clause. - select = ( - scope.expression.selects[column] - if isinstance(column, int) - else next( - (select for select in scope.expression.selects if select.alias_or_name == column), - exp.Star() if scope.expression.is_star else scope.expression, - ) - ) + return to_node(column, scope, dialect) - if isinstance(scope.expression, exp.Union): - upstream = upstream or Node(name="UNION", source=scope.expression, expression=select) - - index = ( - column - if isinstance(column, int) - else next( - ( - i - for i, select in enumerate(scope.expression.selects) - if select.alias_or_name == column or select.is_star - ), - -1, # mypy will not allow a None here, but a negative index should never be returned - ) - ) - - if index == -1: - raise ValueError(f"Could not find {column} in {scope.expression}") - for s in scope.union_scopes: - to_node(index, scope=s, upstream=upstream, alias=alias) - - return upstream - - if isinstance(scope.expression, exp.Select): - # For better ergonomics in our node labels, replace the full select with - # a version that has only the column we care about. - # "x", SELECT x, y FROM foo - # => "x", SELECT x FROM foo - source = t.cast(exp.Expression, scope.expression.select(select, append=False)) - else: - source = scope.expression - - # Create the node for this step in the lineage chain, and attach it to the previous one. 
- node = Node( - name=f"{scope_name}.{column}" if scope_name else str(column), - source=source, - expression=select, - alias=alias or "", +def to_node( + column: str | int, + scope: Scope, + dialect: DialectType, + scope_name: t.Optional[str] = None, + upstream: t.Optional[Node] = None, + source_name: t.Optional[str] = None, + reference_node_name: t.Optional[str] = None, +) -> Node: + source_names = { + dt.alias: dt.comments[0].split()[1] + for dt in scope.derived_tables + if dt.comments and dt.comments[0].startswith("source: ") + } + + # Find the specific select clause that is the source of the column we want. + # This can either be a specific, named select or a generic `*` clause. + select = ( + scope.expression.selects[column] + if isinstance(column, int) + else next( + (select for select in scope.expression.selects if select.alias_or_name == column), + exp.Star() if scope.expression.is_star else scope.expression, ) + ) - if upstream: - upstream.downstream.append(node) + if isinstance(scope.expression, exp.Union): + upstream = upstream or Node(name="UNION", source=scope.expression, expression=select) - subquery_scopes = { - id(subquery_scope.expression): subquery_scope - for subquery_scope in scope.subquery_scopes - } + index = ( + column + if isinstance(column, int) + else next( + ( + i + for i, select in enumerate(scope.expression.selects) + if select.alias_or_name == column or select.is_star + ), + -1, # mypy will not allow a None here, but a negative index should never be returned + ) + ) - for subquery in find_all_in_scope(select, exp.Subqueryable): - subquery_scope = subquery_scopes[id(subquery)] + if index == -1: + raise ValueError(f"Could not find {column} in {scope.expression}") + + for s in scope.union_scopes: + to_node( + index, + scope=s, + dialect=dialect, + upstream=upstream, + source_name=source_name, + reference_node_name=reference_node_name, + ) - for name in subquery.named_selects: - to_node(name, scope=subquery_scope, upstream=node) + return 
upstream + + if isinstance(scope.expression, exp.Select): + # For better ergonomics in our node labels, replace the full select with + # a version that has only the column we care about. + # "x", SELECT x, y FROM foo + # => "x", SELECT x FROM foo + source = t.cast(exp.Expression, scope.expression.select(select, append=False)) + else: + source = scope.expression + + # Create the node for this step in the lineage chain, and attach it to the previous one. + node = Node( + name=f"{scope_name}.{column}" if scope_name else str(column), + source=source, + expression=select, + source_name=source_name or "", + reference_node_name=reference_node_name or "", + ) - # if the select is a star add all scope sources as downstreams - if select.is_star: - for source in scope.sources.values(): - if isinstance(source, Scope): - source = source.expression - node.downstream.append(Node(name=select.sql(), source=source, expression=source)) + if upstream: + upstream.downstream.append(node) - # Find all columns that went into creating this one to list their lineage nodes. 
- source_columns = set(find_all_in_scope(select, exp.Column)) + subquery_scopes = { + id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes + } - # If the source is a UDTF find columns used in the UTDF to generate the table - if isinstance(source, exp.UDTF): - source_columns |= set(source.find_all(exp.Column)) + for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES): + subquery_scope = subquery_scopes.get(id(subquery)) + if not subquery_scope: + logger.warning(f"Unknown subquery scope: {subquery.sql(dialect=dialect)}") + continue - for c in source_columns: - table = c.table - source = scope.sources.get(table) + for name in subquery.named_selects: + to_node(name, scope=subquery_scope, dialect=dialect, upstream=node) + # if the select is a star add all scope sources as downstreams + if select.is_star: + for source in scope.sources.values(): if isinstance(source, Scope): - # The table itself came from a more specific scope. Recurse into that one using the unaliased column name. - to_node( - c.name, - scope=source, - scope_name=table, - upstream=node, - alias=aliases.get(table) or alias, - ) - else: - # The source is not a scope - we've reached the end of the line. At this point, if a source is not found - # it means this column's lineage is unknown. This can happen if the definition of a source used in a query - # is not passed into the `sources` map. - source = source or exp.Placeholder() - node.downstream.append(Node(name=c.sql(), source=source, expression=source)) - - return node + source = source.expression + node.downstream.append(Node(name=select.sql(), source=source, expression=source)) + + # Find all columns that went into creating this one to list their lineage nodes. 
+    source_columns = set(find_all_in_scope(select, exp.Column)) + + # If the source is a UDTF find columns used in the UDTF to generate the table + if isinstance(source, exp.UDTF): + source_columns |= set(source.find_all(exp.Column)) + + for c in source_columns: + table = c.table + source = scope.sources.get(table) + + if isinstance(source, Scope): + selected_node, _ = scope.selected_sources.get(table, (None, None)) + # The table itself came from a more specific scope. Recurse into that one using the unaliased column name. + to_node( + c.name, + scope=source, + dialect=dialect, + scope_name=table, + upstream=node, + source_name=source_names.get(table) or source_name, + reference_node_name=selected_node.name if selected_node else None, + ) + else: + # The source is not a scope - we've reached the end of the line. At this point, if a source is not found + # it means this column's lineage is unknown. This can happen if the definition of a source used in a query + # is not passed into the `sources` map. 
+ source = source or exp.Placeholder() + node.downstream.append(Node(name=c.sql(), source=source, expression=source)) - return to_node(column if isinstance(column, str) else column.name, scope) + return node class GraphHTML: diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index ce274bb..81b1ee6 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -191,6 +191,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.DateToDi, exp.Floor, exp.Levenshtein, + exp.Sign, exp.StrPosition, exp.TsOrDiToDi, }, @@ -262,6 +263,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), exp.Div: lambda self, e: self._annotate_div(e), + exp.Dot: lambda self, e: self._annotate_dot(e), exp.Explode: lambda self, e: self._annotate_explode(e), exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), @@ -273,15 +275,17 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), + exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"), exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), + exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True), exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), exp.Timestamp: lambda self, e: self._annotate_with_type( e, exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP, ), exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), + exp.Unnest: lambda self, e: self._annotate_unnest(e), exp.VarMap: lambda 
self, e: self._annotate_with_type(e, exp.DataType.Type.MAP), - exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True), } NESTED_TYPES = { @@ -380,8 +384,11 @@ class TypeAnnotator(metaclass=_TypeAnnotator): source = scope.sources.get(col.table) if isinstance(source, exp.Table): self._set_type(col, self.schema.get_column_type(source, col)) - elif source and col.table in selects and col.name in selects[col.table]: - self._set_type(col, selects[col.table][col.name].type) + elif source: + if col.table in selects and col.name in selects[col.table]: + self._set_type(col, selects[col.table][col.name].type) + elif isinstance(source.expression, exp.Unnest): + self._set_type(col, source.expression.type) # Then (possibly) annotate the remaining expressions in the scope self._maybe_annotate(scope.expression) @@ -514,7 +521,14 @@ class TypeAnnotator(metaclass=_TypeAnnotator): last_datatype = None for expr in expressions: - last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type) + expr_type = expr.type + + # Stop at the first nested data type found - we don't want to _maybe_coerce nested types + if expr_type.args.get("nested"): + last_datatype = expr_type + break + + last_datatype = self._maybe_coerce(last_datatype or expr_type, expr_type) self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN) @@ -594,7 +608,26 @@ class TypeAnnotator(metaclass=_TypeAnnotator): return expression + def _annotate_dot(self, expression: exp.Dot) -> exp.Dot: + self._annotate_args(expression) + self._set_type(expression, None) + this_type = expression.this.type + + if this_type and this_type.is_type(exp.DataType.Type.STRUCT): + for e in this_type.expressions: + if e.name == expression.expression.name: + self._set_type(expression, e.kind) + break + + return expression + def _annotate_explode(self, expression: exp.Explode) -> exp.Explode: self._annotate_args(expression) self._set_type(expression, seq_get(expression.this.type.expressions, 0)) 
return expression + + def _annotate_unnest(self, expression: exp.Unnest) -> exp.Unnest: + self._annotate_args(expression) + child = seq_get(expression.expressions, 0) + self._set_type(expression, child and seq_get(child.type.expressions, 0)) + return expression diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py index d22a998..f2a0990 100644 --- a/sqlglot/optimizer/normalize_identifiers.py +++ b/sqlglot/optimizer/normalize_identifiers.py @@ -10,13 +10,11 @@ if t.TYPE_CHECKING: @t.overload -def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: - ... +def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: ... @t.overload -def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: - ... +def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: ... def normalize_identifiers(expression, dialect=None): diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index ef589c9..233ffc9 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -120,6 +120,8 @@ def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) 
AS foo(col1, col2) """ for derived_table in derived_tables: + if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive: + continue table_alias = derived_table.args.get("alias") if table_alias: table_alias.args.pop("columns", None) @@ -214,7 +216,13 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: table = resolver.get_table(column.name) if resolve_table and not column.table else None alias_expr, i = alias_to_expression.get(column.name, (None, 1)) double_agg = ( - (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc)) + ( + alias_expr.find(exp.AggFunc) + and ( + column.find_ancestor(exp.AggFunc) + and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window) + ) + ) if alias_expr else False ) @@ -404,7 +412,7 @@ def _expand_stars( tables = list(scope.selected_sources) _add_except_columns(expression, tables, except_columns) _add_replace_columns(expression, tables, replace_columns) - elif expression.is_star: + elif expression.is_star and not isinstance(expression, exp.Dot): tables = [expression.table] _add_except_columns(expression.this, tables, except_columns) _add_replace_columns(expression.this, tables, replace_columns) @@ -437,7 +445,7 @@ def _expand_stars( if pivot_columns: new_selections.extend( - exp.alias_(exp.column(name, table=pivot.alias), name, copy=False) + alias(exp.column(name, table=pivot.alias), name, copy=False) for name in pivot_columns if name not in columns_to_exclude ) @@ -466,7 +474,7 @@ def _expand_stars( ) # Ensures we don't overwrite the initial selections with an empty list - if new_selections: + if new_selections and isinstance(scope.expression, exp.Select): scope.expression.set("expressions", new_selections) @@ -528,7 +536,8 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None: new_selections.append(selection) - scope.expression.set("expressions", new_selections) + if isinstance(scope.expression, exp.Select): + scope.expression.set("expressions", 
new_selections) def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E: @@ -615,7 +624,7 @@ class Resolver: node, _ = self.scope.selected_sources.get(table_name) - if isinstance(node, exp.Subqueryable): + if isinstance(node, exp.Query): while node and node.alias != table_name: node = node.parent diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index d460e81..214ac0a 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -55,8 +55,8 @@ def qualify_tables( if not table.args.get("catalog") and table.args.get("db"): table.set("catalog", catalog) - if not isinstance(expression, exp.Subqueryable): - for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Unionable)): + if not isinstance(expression, exp.Query): + for node, *_ in expression.walk(prune=lambda n, *_: isinstance(n, exp.Query)): if isinstance(node, exp.Table): _qualify(node) diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 0eae979..443fa6c 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -138,7 +138,7 @@ class Scope: and _is_derived_table(node) ): self._derived_tables.append(node) - elif isinstance(node, exp.Subqueryable): + elif isinstance(node, exp.UNWRAPPED_QUERIES): self._subqueries.append(node) self._collected = True @@ -225,7 +225,7 @@ class Scope: SELECT * FROM x WHERE a IN (SELECT ...) 
<- that's a subquery Returns: - list[exp.Subqueryable]: subqueries + list[exp.Select | exp.Union]: subqueries """ self._ensure_collected() return self._subqueries @@ -486,8 +486,8 @@ def traverse_scope(expression: exp.Expression) -> t.List[Scope]: Returns: list[Scope]: scope instances """ - if isinstance(expression, exp.Unionable) or ( - isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Unionable) + if isinstance(expression, exp.Query) or ( + isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Query) ): return list(_traverse_scope(Scope(expression))) @@ -615,7 +615,7 @@ def _is_derived_table(expression: exp.Subquery) -> bool: as it doesn't introduce a new scope. If an alias is present, it shadows all names under the Subquery, so that's one exception to this rule. """ - return bool(expression.alias or isinstance(expression.this, exp.Subqueryable)) + return bool(expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES)) def _traverse_tables(scope): @@ -786,7 +786,7 @@ def walk_in_scope(expression, bfs=True, prune=None): and _is_derived_table(node) ) or isinstance(node, exp.UDTF) - or isinstance(node, exp.Subqueryable) + or isinstance(node, exp.UNWRAPPED_QUERIES) ): crossed_scope_boundary = True diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 9ffddb5..2e43d21 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -1185,7 +1185,7 @@ def gen(expression: t.Any) -> str: GEN_MAP = { exp.Add: lambda e: _binary(e, "+"), exp.And: lambda e: _binary(e, "AND"), - exp.Anonymous: lambda e: f"{e.this.upper()} {','.join(gen(e) for e in e.expressions)}", + exp.Anonymous: lambda e: _anonymous(e), exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}", exp.Boolean: lambda e: "TRUE" if e.this else "FALSE", exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]", @@ -1219,6 +1219,20 @@ GEN_MAP = { } +def _anonymous(e: 
exp.Anonymous) -> str: + this = e.this + if isinstance(this, str): + name = this.upper() + elif isinstance(this, exp.Identifier): + name = f'"{this.name}"' if this.quoted else this.name.upper() + else: + raise ValueError( + f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'." + ) + + return f"{name} {','.join(gen(e) for e in e.expressions)}" + + def _binary(e: exp.Binary, op: str) -> str: return f"{gen(e.left)} {op} {gen(e.right)}" diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index b4c7475..36d9da4 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -94,8 +94,20 @@ def unnest(select, parent_select, next_alias_name): else: _replace(predicate, join_key_not_null) + group = select.args.get("group") + + if group: + if {value.this} != set(group.expressions): + select = ( + exp.select(exp.column(value.alias, "_q")) + .from_(select.subquery("_q", copy=False), copy=False) + .group_by(exp.column(value.alias, "_q"), copy=False) + ) + else: + select = select.group_by(value.this, copy=False) + parent_select.join( - select.group_by(value.this, copy=False), + select, on=column.eq(join_key), join_type="LEFT", join_alias=alias, diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 4e7f870..49dac2e 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -17,6 +17,8 @@ if t.TYPE_CHECKING: logger = logging.getLogger("sqlglot") +OPTIONS_TYPE = t.Dict[str, t.Sequence[t.Union[t.Sequence[str], str]]] + def build_var_map(args: t.List) -> exp.StarMap | exp.VarMap: if len(args) == 1 and args[0].is_star: @@ -367,6 +369,7 @@ class Parser(metaclass=_Parser): TokenType.TEMPORARY, TokenType.TOP, TokenType.TRUE, + TokenType.TRUNCATE, TokenType.UNIQUE, TokenType.UNPIVOT, TokenType.UPDATE, @@ -435,6 +438,7 @@ class Parser(metaclass=_Parser): TokenType.TABLE, TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, + TokenType.TRUNCATE, TokenType.WINDOW, TokenType.XOR, *TYPE_TOKENS, 
@@ -578,7 +582,7 @@ class Parser(metaclass=_Parser): exp.Column: lambda self: self._parse_column(), exp.Condition: lambda self: self._parse_conjunction(), exp.DataType: lambda self: self._parse_types(allow_identifiers=False), - exp.Expression: lambda self: self._parse_statement(), + exp.Expression: lambda self: self._parse_expression(), exp.From: lambda self: self._parse_from(), exp.Group: lambda self: self._parse_group(), exp.Having: lambda self: self._parse_having(), @@ -625,10 +629,10 @@ class Parser(metaclass=_Parser): TokenType.SET: lambda self: self._parse_set(), TokenType.UNCACHE: lambda self: self._parse_uncache(), TokenType.UPDATE: lambda self: self._parse_update(), + TokenType.TRUNCATE: lambda self: self._parse_truncate_table(), TokenType.USE: lambda self: self.expression( exp.Use, - kind=self._match_texts(("ROLE", "WAREHOUSE", "DATABASE", "SCHEMA")) - and exp.var(self._prev.text), + kind=self._parse_var_from_options(self.USABLES, raise_unmatched=False), this=self._parse_table(schema=False), ), } @@ -642,36 +646,44 @@ class Parser(metaclass=_Parser): TokenType.DPIPE_SLASH: lambda self: self.expression(exp.Cbrt, this=self._parse_unary()), } - PRIMARY_PARSERS = { - TokenType.STRING: lambda self, token: self.expression( - exp.Literal, this=token.text, is_string=True - ), - TokenType.NUMBER: lambda self, token: self.expression( - exp.Literal, this=token.text, is_string=False - ), - TokenType.STAR: lambda self, _: self.expression( - exp.Star, **{"except": self._parse_except(), "replace": self._parse_replace()} + STRING_PARSERS = { + TokenType.HEREDOC_STRING: lambda self, token: self.expression( + exp.RawString, this=token.text ), - TokenType.NULL: lambda self, _: self.expression(exp.Null), - TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True), - TokenType.FALSE: lambda self, _: self.expression(exp.Boolean, this=False), - TokenType.BIT_STRING: lambda self, token: self.expression(exp.BitString, this=token.text), - TokenType.HEX_STRING: lambda 
self, token: self.expression(exp.HexString, this=token.text), - TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text), - TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token), TokenType.NATIONAL_STRING: lambda self, token: self.expression( exp.National, this=token.text ), TokenType.RAW_STRING: lambda self, token: self.expression(exp.RawString, this=token.text), - TokenType.HEREDOC_STRING: lambda self, token: self.expression( - exp.RawString, this=token.text + TokenType.STRING: lambda self, token: self.expression( + exp.Literal, this=token.text, is_string=True ), TokenType.UNICODE_STRING: lambda self, token: self.expression( exp.UnicodeString, this=token.text, escape=self._match_text_seq("UESCAPE") and self._parse_string(), ), + } + + NUMERIC_PARSERS = { + TokenType.BIT_STRING: lambda self, token: self.expression(exp.BitString, this=token.text), + TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text), + TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text), + TokenType.NUMBER: lambda self, token: self.expression( + exp.Literal, this=token.text, is_string=False + ), + } + + PRIMARY_PARSERS = { + **STRING_PARSERS, + **NUMERIC_PARSERS, + TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token), + TokenType.NULL: lambda self, _: self.expression(exp.Null), + TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True), + TokenType.FALSE: lambda self, _: self.expression(exp.Boolean, this=False), TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(), + TokenType.STAR: lambda self, _: self.expression( + exp.Star, **{"except": self._parse_except(), "replace": self._parse_replace()} + ), } PLACEHOLDER_PARSERS = { @@ -799,7 +811,9 @@ class Parser(metaclass=_Parser): exp.CharacterSetColumnConstraint, this=self._parse_var_or_string() ), "CHECK": lambda self: self.expression( - exp.CheckColumnConstraint, 
this=self._parse_wrapped(self._parse_conjunction) + exp.CheckColumnConstraint, + this=self._parse_wrapped(self._parse_conjunction), + enforced=self._match_text_seq("ENFORCED"), ), "COLLATE": lambda self: self.expression( exp.CollateColumnConstraint, this=self._parse_var() @@ -873,6 +887,8 @@ class Parser(metaclass=_Parser): FUNCTIONS_WITH_ALIASED_ARGS = {"STRUCT"} + KEY_VALUE_DEFINITIONS = (exp.Alias, exp.EQ, exp.PropertyEQ, exp.Slice) + FUNCTION_PARSERS = { "CAST": lambda self: self._parse_cast(self.STRICT_CAST), "CONVERT": lambda self: self._parse_convert(self.STRICT_CAST), @@ -895,6 +911,7 @@ class Parser(metaclass=_Parser): QUERY_MODIFIER_PARSERS = { TokenType.MATCH_RECOGNIZE: lambda self: ("match", self._parse_match_recognize()), + TokenType.PREWHERE: lambda self: ("prewhere", self._parse_prewhere()), TokenType.WHERE: lambda self: ("where", self._parse_where()), TokenType.GROUP_BY: lambda self: ("group", self._parse_group()), TokenType.HAVING: lambda self: ("having", self._parse_having()), @@ -934,22 +951,23 @@ class Parser(metaclass=_Parser): exp.DataType.Type.JSON: lambda self, this, _: self.expression(exp.ParseJSON, this=this), } - MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table) - DDL_SELECT_TOKENS = {TokenType.SELECT, TokenType.WITH, TokenType.L_PAREN} PRE_VOLATILE_TOKENS = {TokenType.CREATE, TokenType.REPLACE, TokenType.UNIQUE} TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"} - TRANSACTION_CHARACTERISTICS = { - "ISOLATION LEVEL REPEATABLE READ", - "ISOLATION LEVEL READ COMMITTED", - "ISOLATION LEVEL READ UNCOMMITTED", - "ISOLATION LEVEL SERIALIZABLE", - "READ WRITE", - "READ ONLY", + TRANSACTION_CHARACTERISTICS: OPTIONS_TYPE = { + "ISOLATION": ( + ("LEVEL", "REPEATABLE", "READ"), + ("LEVEL", "READ", "COMMITTED"), + ("LEVEL", "READ", "UNCOMITTED"), + ("LEVEL", "SERIALIZABLE"), + ), + "READ": ("WRITE", "ONLY"), } + USABLES: OPTIONS_TYPE = dict.fromkeys(("ROLE", "WAREHOUSE", "DATABASE", "SCHEMA"), tuple()) + INSERT_ALTERNATIVES = {"ABORT", 
"FAIL", "IGNORE", "REPLACE", "ROLLBACK"} CLONE_KEYWORDS = {"CLONE", "COPY"} @@ -1012,6 +1030,9 @@ class Parser(metaclass=_Parser): # If this is True and '(' is not found, the keyword will be treated as an identifier VALUES_FOLLOWED_BY_PAREN = True + # Whether implicit unnesting is supported, e.g. SELECT 1 FROM y.z AS z, z.a (Redshift) + SUPPORTS_IMPLICIT_UNNEST = False + __slots__ = ( "error_level", "error_message_context", @@ -2450,10 +2471,37 @@ class Parser(metaclass=_Parser): alias=self._parse_table_alias() if parse_alias else None, ) + def _implicit_unnests_to_explicit(self, this: E) -> E: + from sqlglot.optimizer.normalize_identifiers import normalize_identifiers as _norm + + refs = {_norm(this.args["from"].this.copy(), dialect=self.dialect).alias_or_name} + for i, join in enumerate(this.args.get("joins") or []): + table = join.this + normalized_table = table.copy() + normalized_table.meta["maybe_column"] = True + normalized_table = _norm(normalized_table, dialect=self.dialect) + + if isinstance(table, exp.Table) and not join.args.get("on"): + if normalized_table.parts[0].name in refs: + table_as_column = table.to_column() + unnest = exp.Unnest(expressions=[table_as_column]) + + # Table.to_column creates a parent Alias node that we want to convert to + # a TableAlias and attach to the Unnest, so it matches the parser's output + if isinstance(table.args.get("alias"), exp.TableAlias): + table_as_column.replace(table_as_column.this) + exp.alias_(unnest, None, table=[table.args["alias"].this], copy=False) + + table.replace(unnest) + + refs.add(normalized_table.alias_or_name) + + return this + def _parse_query_modifiers( self, this: t.Optional[exp.Expression] ) -> t.Optional[exp.Expression]: - if isinstance(this, self.MODIFIABLES): + if isinstance(this, (exp.Query, exp.Table)): for join in iter(self._parse_join, None): this.append("joins", join) for lateral in iter(self._parse_lateral, None): @@ -2478,6 +2526,10 @@ class Parser(metaclass=_Parser): 
offset.set("expressions", limit_by_expressions) continue break + + if self.SUPPORTS_IMPLICIT_UNNEST and this and "from" in this.args: + this = self._implicit_unnests_to_explicit(this) + return this def _parse_hint(self) -> t.Optional[exp.Hint]: @@ -2803,7 +2855,9 @@ class Parser(metaclass=_Parser): or self._parse_placeholder() ) - def _parse_table_parts(self, schema: bool = False, is_db_reference: bool = False) -> exp.Table: + def _parse_table_parts( + self, schema: bool = False, is_db_reference: bool = False, wildcard: bool = False + ) -> exp.Table: catalog = None db = None table: t.Optional[exp.Expression | str] = self._parse_table_part(schema=schema) @@ -2817,8 +2871,20 @@ class Parser(metaclass=_Parser): else: catalog = db db = table + # "" used for tsql FROM a..b case table = self._parse_table_part(schema=schema) or "" + if ( + wildcard + and self._is_connected() + and (isinstance(table, exp.Identifier) or not table) + and self._match(TokenType.STAR) + ): + if isinstance(table, exp.Identifier): + table.args["this"] += "*" + else: + table = exp.Identifier(this="*") + if is_db_reference: catalog = db db = table @@ -2861,6 +2927,9 @@ class Parser(metaclass=_Parser): bracket = parse_bracket and self._parse_bracket(None) bracket = self.expression(exp.Table, this=bracket) if bracket else None + + only = self._match(TokenType.ONLY) + this = t.cast( exp.Expression, bracket @@ -2869,6 +2938,12 @@ class Parser(metaclass=_Parser): ), ) + if only: + this.set("only", only) + + # Postgres supports a wildcard (table) suffix operator, which is a no-op in this context + self._match_text_seq("*") + if schema: return self._parse_schema(this=this) @@ -3161,6 +3236,14 @@ class Parser(metaclass=_Parser): def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]: return [agg.alias for agg in aggregations] + def _parse_prewhere(self, skip_where_token: bool = False) -> t.Optional[exp.PreWhere]: + if not skip_where_token and not self._match(TokenType.PREWHERE): 
+ return None + + return self.expression( + exp.PreWhere, comments=self._prev_comments, this=self._parse_conjunction() + ) + def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Where]: if not skip_where_token and not self._match(TokenType.WHERE): return None @@ -3291,8 +3374,12 @@ class Parser(metaclass=_Parser): return None return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered)) - def _parse_ordered(self, parse_method: t.Optional[t.Callable] = None) -> exp.Ordered: + def _parse_ordered( + self, parse_method: t.Optional[t.Callable] = None + ) -> t.Optional[exp.Ordered]: this = parse_method() if parse_method else self._parse_conjunction() + if not this: + return None asc = self._match(TokenType.ASC) desc = self._match(TokenType.DESC) or (asc and False) @@ -3510,7 +3597,7 @@ class Parser(metaclass=_Parser): if self._match_text_seq("DISTINCT", "FROM"): klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ - return self.expression(klass, this=this, expression=self._parse_conjunction()) + return self.expression(klass, this=this, expression=self._parse_bitwise()) expression = self._parse_null() or self._parse_boolean() if not expression: @@ -3528,7 +3615,7 @@ class Parser(metaclass=_Parser): matched_l_paren = self._prev.token_type == TokenType.L_PAREN expressions = self._parse_csv(lambda: self._parse_select_or_expression(alias=alias)) - if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable): + if len(expressions) == 1 and isinstance(expressions[0], exp.Query): this = self.expression(exp.In, this=this, query=expressions[0]) else: this = self.expression(exp.In, this=this, expressions=expressions) @@ -3959,7 +4046,7 @@ class Parser(metaclass=_Parser): this = self._parse_query_modifiers(seq_get(expressions, 0)) - if isinstance(this, exp.Subqueryable): + if isinstance(this, exp.UNWRAPPED_QUERIES): this = self._parse_set_operations( self._parse_subquery(this=this, parse_alias=False) ) @@ -4064,6 +4151,9 @@ 
class Parser(metaclass=_Parser): alias = upper in self.FUNCTIONS_WITH_ALIASED_ARGS args = self._parse_csv(lambda: self._parse_lambda(alias=alias)) + if alias: + args = self._kv_to_prop_eq(args) + if function and not anonymous: if "dialect" in function.__code__.co_varnames: func = function(args, dialect=self.dialect) @@ -4076,6 +4166,8 @@ class Parser(metaclass=_Parser): this = func else: + if token_type == TokenType.IDENTIFIER: + this = exp.Identifier(this=this, quoted=True) this = self.expression(exp.Anonymous, this=this, expressions=args) if isinstance(this, exp.Expression): @@ -4084,6 +4176,26 @@ class Parser(metaclass=_Parser): self._match_r_paren(this) return self._parse_window(this) + def _kv_to_prop_eq(self, expressions: t.List[exp.Expression]) -> t.List[exp.Expression]: + transformed = [] + + for e in expressions: + if isinstance(e, self.KEY_VALUE_DEFINITIONS): + if isinstance(e, exp.Alias): + e = self.expression(exp.PropertyEQ, this=e.args.get("alias"), expression=e.this) + + if not isinstance(e, exp.PropertyEQ): + e = self.expression( + exp.PropertyEQ, this=exp.to_identifier(e.name), expression=e.expression + ) + + if isinstance(e.this, exp.Column): + e.this.replace(e.this.this) + + transformed.append(e) + + return transformed + def _parse_function_parameter(self) -> t.Optional[exp.Expression]: return self._parse_column_def(self._parse_id_var()) @@ -4496,7 +4608,7 @@ class Parser(metaclass=_Parser): # https://duckdb.org/docs/sql/data_types/struct.html#creating-structs if bracket_kind == TokenType.L_BRACE: - this = self.expression(exp.Struct, expressions=expressions) + this = self.expression(exp.Struct, expressions=self._kv_to_prop_eq(expressions)) elif not this or this.name.upper() == "ARRAY": this = self.expression(exp.Array, expressions=expressions) else: @@ -4747,12 +4859,10 @@ class Parser(metaclass=_Parser): return None @t.overload - def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: - ... 
+ def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ... @t.overload - def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: - ... + def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ... def _parse_json_object(self, agg=False): star = self._parse_star() @@ -5140,16 +5250,16 @@ class Parser(metaclass=_Parser): return None def _parse_string(self) -> t.Optional[exp.Expression]: - if self._match_set((TokenType.STRING, TokenType.RAW_STRING)): - return self.PRIMARY_PARSERS[self._prev.token_type](self, self._prev) + if self._match_set(self.STRING_PARSERS): + return self.STRING_PARSERS[self._prev.token_type](self, self._prev) return self._parse_placeholder() def _parse_string_as_identifier(self) -> t.Optional[exp.Identifier]: return exp.to_identifier(self._match(TokenType.STRING) and self._prev.text, quoted=True) def _parse_number(self) -> t.Optional[exp.Expression]: - if self._match(TokenType.NUMBER): - return self.PRIMARY_PARSERS[TokenType.NUMBER](self, self._prev) + if self._match_set(self.NUMERIC_PARSERS): + return self.NUMERIC_PARSERS[self._prev.token_type](self, self._prev) return self._parse_placeholder() def _parse_identifier(self) -> t.Optional[exp.Expression]: @@ -5182,6 +5292,9 @@ class Parser(metaclass=_Parser): def _parse_var_or_string(self) -> t.Optional[exp.Expression]: return self._parse_var() or self._parse_string() + def _parse_primary_or_var(self) -> t.Optional[exp.Expression]: + return self._parse_primary() or self._parse_var(any_token=True) + def _parse_null(self) -> t.Optional[exp.Expression]: if self._match_set(self.NULL_TOKENS): return self.PRIMARY_PARSERS[TokenType.NULL](self, self._prev) @@ -5200,16 +5313,12 @@ class Parser(metaclass=_Parser): return self._parse_placeholder() def _parse_parameter(self) -> exp.Parameter: - def _parse_parameter_part() -> t.Optional[exp.Expression]: - return ( - self._parse_identifier() or self._parse_primary() or self._parse_var(any_token=True) - ) - 
self._match(TokenType.L_BRACE) - this = _parse_parameter_part() - expression = self._match(TokenType.COLON) and _parse_parameter_part() + this = self._parse_identifier() or self._parse_primary_or_var() + expression = self._match(TokenType.COLON) and ( + self._parse_identifier() or self._parse_primary_or_var() + ) self._match(TokenType.R_BRACE) - return self.expression(exp.Parameter, this=this, expression=expression) def _parse_placeholder(self) -> t.Optional[exp.Expression]: @@ -5376,35 +5485,15 @@ class Parser(metaclass=_Parser): exp.DropPartition, expressions=self._parse_csv(self._parse_partition), exists=exists ) - def _parse_add_constraint(self) -> exp.AddConstraint: - this = None - kind = self._prev.token_type - - if kind == TokenType.CONSTRAINT: - this = self._parse_id_var() - - if self._match_text_seq("CHECK"): - expression = self._parse_wrapped(self._parse_conjunction) - enforced = self._match_text_seq("ENFORCED") or False - - return self.expression( - exp.AddConstraint, this=this, expression=expression, enforced=enforced - ) - - if kind == TokenType.FOREIGN_KEY or self._match(TokenType.FOREIGN_KEY): - expression = self._parse_foreign_key() - elif kind == TokenType.PRIMARY_KEY or self._match(TokenType.PRIMARY_KEY): - expression = self._parse_primary_key() - else: - expression = None - - return self.expression(exp.AddConstraint, this=this, expression=expression) - def _parse_alter_table_add(self) -> t.List[exp.Expression]: index = self._index - 1 - if self._match_set(self.ADD_CONSTRAINT_TOKENS): - return self._parse_csv(self._parse_add_constraint) + if self._match_set(self.ADD_CONSTRAINT_TOKENS, advance=False): + return self._parse_csv( + lambda: self.expression( + exp.AddConstraint, expressions=self._parse_csv(self._parse_constraint) + ) + ) self._retreat(index) if not self.ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN and self._match_text_seq("ADD"): @@ -5472,6 +5561,7 @@ class Parser(metaclass=_Parser): parser = self.ALTER_PARSERS.get(self._prev.text.upper()) 
if self._prev else None if parser: actions = ensure_list(parser(self)) + options = self._parse_csv(self._parse_property) if not self._curr and actions: return self.expression( @@ -5480,6 +5570,7 @@ class Parser(metaclass=_Parser): exists=exists, actions=actions, only=only, + options=options, ) return self._parse_as_command(start) @@ -5610,11 +5701,34 @@ class Parser(metaclass=_Parser): return set_ - def _parse_var_from_options(self, options: t.Collection[str]) -> t.Optional[exp.Var]: - for option in options: - if self._match_text_seq(*option.split(" ")): - return exp.var(option) - return None + def _parse_var_from_options( + self, options: OPTIONS_TYPE, raise_unmatched: bool = True + ) -> t.Optional[exp.Var]: + start = self._curr + if not start: + return None + + option = start.text.upper() + continuations = options.get(option) + + index = self._index + self._advance() + for keywords in continuations or []: + if isinstance(keywords, str): + keywords = (keywords,) + + if self._match_text_seq(*keywords): + option = f"{option} {' '.join(keywords)}" + break + else: + if continuations or continuations is None: + if raise_unmatched: + self.raise_error(f"Unknown option {option}") + + self._retreat(index) + return None + + return exp.var(option) def _parse_as_command(self, start: Token) -> exp.Command: while self._curr: @@ -5806,14 +5920,12 @@ class Parser(metaclass=_Parser): return True @t.overload - def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression: - ... + def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression: ... @t.overload def _replace_columns_with_dots( self, this: t.Optional[exp.Expression] - ) -> t.Optional[exp.Expression]: - ... + ) -> t.Optional[exp.Expression]: ... 
def _replace_columns_with_dots(self, this): if isinstance(this, exp.Dot): @@ -5849,3 +5961,53 @@ class Parser(metaclass=_Parser): else: column.replace(dot_or_id) return node + + def _parse_truncate_table(self) -> t.Optional[exp.TruncateTable] | exp.Expression: + start = self._prev + + # Not to be confused with TRUNCATE(number, decimals) function call + if self._match(TokenType.L_PAREN): + self._retreat(self._index - 2) + return self._parse_function() + + # Clickhouse supports TRUNCATE DATABASE as well + is_database = self._match(TokenType.DATABASE) + + self._match(TokenType.TABLE) + + exists = self._parse_exists(not_=False) + + expressions = self._parse_csv( + lambda: self._parse_table(schema=True, is_db_reference=is_database) + ) + + cluster = self._parse_on_property() if self._match(TokenType.ON) else None + + if self._match_text_seq("RESTART", "IDENTITY"): + identity = "RESTART" + elif self._match_text_seq("CONTINUE", "IDENTITY"): + identity = "CONTINUE" + else: + identity = None + + if self._match_text_seq("CASCADE") or self._match_text_seq("RESTRICT"): + option = self._prev.text + else: + option = None + + partition = self._parse_partition() + + # Fallback case + if self._curr: + return self._parse_as_command(start) + + return self.expression( + exp.TruncateTable, + expressions=expressions, + is_database=is_database, + exists=exists, + cluster=cluster, + identity=identity, + option=option, + partition=partition, + ) diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 939ca18..da9df7d 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -302,6 +302,7 @@ class TokenType(AutoName): OBJECT_IDENTIFIER = auto() OFFSET = auto() ON = auto() + ONLY = auto() OPERATOR = auto() ORDER_BY = auto() ORDER_SIBLINGS_BY = auto() @@ -317,6 +318,7 @@ class TokenType(AutoName): PIVOT = auto() PLACEHOLDER = auto() PRAGMA = auto() + PREWHERE = auto() PRIMARY_KEY = auto() PROCEDURE = auto() PROPERTIES = auto() @@ -353,6 +355,7 @@ class TokenType(AutoName): TOP = auto() THEN 
= auto() TRUE = auto() + TRUNCATE = auto() UNCACHE = auto() UNION = auto() UNNEST = auto() @@ -370,6 +373,7 @@ class TokenType(AutoName): UNIQUE = auto() VERSION_SNAPSHOT = auto() TIMESTAMP_SNAPSHOT = auto() + OPTION = auto() _ALL_TOKEN_TYPES = list(TokenType) @@ -657,6 +661,7 @@ class Tokenizer(metaclass=_Tokenizer): "DROP": TokenType.DROP, "ELSE": TokenType.ELSE, "END": TokenType.END, + "ENUM": TokenType.ENUM, "ESCAPE": TokenType.ESCAPE, "EXCEPT": TokenType.EXCEPT, "EXECUTE": TokenType.EXECUTE, @@ -752,6 +757,7 @@ class Tokenizer(metaclass=_Tokenizer): "TEMPORARY": TokenType.TEMPORARY, "THEN": TokenType.THEN, "TRUE": TokenType.TRUE, + "TRUNCATE": TokenType.TRUNCATE, "UNION": TokenType.UNION, "UNKNOWN": TokenType.UNKNOWN, "UNNEST": TokenType.UNNEST, @@ -860,7 +866,6 @@ class Tokenizer(metaclass=_Tokenizer): "GRANT": TokenType.COMMAND, "OPTIMIZE": TokenType.COMMAND, "PREPARE": TokenType.COMMAND, - "TRUNCATE": TokenType.COMMAND, "VACUUM": TokenType.COMMAND, "USER-DEFINED": TokenType.USERDEFINED, "FOR VERSION": TokenType.VERSION_SNAPSHOT, @@ -1036,12 +1041,6 @@ class Tokenizer(metaclass=_Tokenizer): def _text(self) -> str: return self.sql[self._start : self._current] - def peek(self, i: int = 0) -> str: - i = self._current + i - if i < self.size: - return self.sql[i] - return "" - def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None: self._prev_token_line = self._line @@ -1182,12 +1181,8 @@ class Tokenizer(metaclass=_Tokenizer): if self._peek.isdigit(): self._advance() elif self._peek == "." 
and not decimal: - after = self.peek(1) - if after.isdigit() or not after.isalpha(): - decimal = True - self._advance() - else: - return self._add(TokenType.VAR) + decimal = True + self._advance() elif self._peek in ("-", "+") and scientific == 1: scientific += 1 self._advance() diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 4777609..04c1f7b 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -547,7 +547,7 @@ def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Exp prop and prop.this and isinstance(prop.this, exp.Schema) - and all(isinstance(e, exp.ColumnDef) and e.args.get("kind") for e in prop.this.expressions) + and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions) ): prop_this = exp.Tuple( expressions=[exp.to_identifier(e.this) for e in prop.this.expressions] @@ -560,6 +560,22 @@ def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Exp return expression +def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression: + """ + Convert struct arguments to aliases: STRUCT(1 AS y) . + """ + if isinstance(expression, exp.Struct): + expression.set( + "expressions", + [ + exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e + for e in expression.expressions + ], + ) + + return expression + + def preprocess( transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], ) -> t.Callable[[Generator, exp.Expression], str]: |